Compare commits

..

55 Commits

Author SHA1 Message Date
Diego Devesa b617cfd289 ggml-alloc : fix leak when reusing a tensor with a larger size (#16679) 2025-10-20 14:53:50 +02:00
Aleksander Grygier 79068501fa Prevent premature submission on IME input (#16673)
* fix: Prevent premature submission on IME input

* chore: update webui static build

* refactor: Put IME completion checker in a helper function and add checking for `KeyboardEvent.eventKey === 229`

* chore: update webui static build

* chore: update webui static build

* chore: update webui static build
2025-10-20 14:21:12 +02:00
Aleksander Grygier 0e4a0cf2fa Import/Export UX improvements (#16619)
* webui : added download action (#13552)

* webui : import and export (for all conversations)

* webui : fixed download-format, import of one conversation

* webui : add ExportedConversations type for chat import/export

* feat: Update naming & order

* chore: Linting

* feat: Import/Export UX improvements

* chore: update webui build output

* feat: Update UI placement of Import/Export tab in Chat Settings Dialog

* refactor: Cleanup

chore: update webui build output

* feat: Enable shift-click multiple conversation items selection

* chore: update webui static build

* chore: update webui static build

---------

Co-authored-by: Sascha Rogmann <github@rogmann.org>
2025-10-20 13:29:14 +02:00
Aleksander Grygier 13f2cfad41 Enable per-conversation loading states to allow having parallel conversations (#16327)
* feat: Per-conversation loading states and tracking streaming stats

* chore: update webui build output

* refactor: Chat state management

Consolidates loading state management by using a global `isLoading` store synchronized with individual conversation states.

This change ensures proper reactivity and avoids potential race conditions when updating the UI based on the loading status of different conversations. It also improves the accuracy of statistics displayed.

Additionally, slots service methods are updated to use conversation IDs for per-conversation state management, avoiding global state pollution.

* feat: Adds loading indicator to conversation items

* chore: update webui build output

* fix: Fix aborting chat streaming

Improves the chat stream abortion process by ensuring that partial responses are saved before the abort signal is sent.

This avoids a race condition where the onError callback could clear the streaming state before the partial response is saved. Additionally, the stream reading loop and callbacks are now checked for abort signals to prevent further processing after abortion.

* refactor: Remove redundant comments

* chore: build webui static output

* refactor: Cleanup

* chore: update webui build output

* chore: update webui build output

* fix: Conversation loading indicator for regenerating messages

* chore: update webui static build

* feat: Improve configuration

* feat: Install `http-server` as dev dependency to not need to rely on `npx` in CI
2025-10-20 12:41:13 +02:00
takuya kodama 06332e2867 llama-batch: fix build fails with -Werror=missing-braces (#16614)
## Why it failed

When compiling with strict compiler flags (-Wmissing-braces -Werror=missing-braces),
the build fails with the following error:

```
cmake \
  -S . \
  -B ../llama.cpp.build \
  --preset=x64-linux-gcc-debug \
  -DCMAKE_INSTALL_PREFIX=/tmp/local \
  -DCMAKE_CXX_FLAGS="-Wmissing-braces -Werror=missing-braces" && \
cmake --build ../llama.cpp.build/
...
In file included from /home/otegami/work/cpp/llama.cpp/src/llama-graph.h:4,
                 from /home/otegami/work/cpp/llama.cpp/src/llama-model.h:5,
                 from /home/otegami/work/cpp/llama.cpp/src/llama.cpp:8:
/home/otegami/work/cpp/llama.cpp/src/llama-batch.h:126:48: error: missing braces around initializer for 'std::__array_traits<int, 1>::_Type' {aka 'int [1]'} [-Werror=missing-braces]
  126 |     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
      |                                                ^
cc1plus: some warnings being treated as errors
```

The issue is that std::array initialization requires double braces.

## How to fix

This PR changes `{ 0 }` to `{{ 0 }}` for std::array initialization.

This is part of a series of commits to fix missing braces warnings across the codebase.
- src/llama-batch.h <- This PR is here.
- src/llama-context.cpp
- tests/test-backend-ops.cpp
- tests/test-gguf.cpp
- tools/mtmd/clip.cpp

Benefits:
- std::array is a struct containing a C-style array, requiring nested braces
- Enables stricter compiler warnings to catch potential issues
2025-10-20 11:27:09 +03:00
Ron Evans 72d53e6c4d readme: update bindings (#16651)
Signed-off-by: deadprogram <ron@hybridgroup.com>
2025-10-20 11:20:04 +03:00
safranowith 2330de7b84 SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators (#16613)
* SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators

Clean up unrelated changes from previous commit

* Chore: remove empty lines and fix indentation

* Clean up: remove leftover blank lines and fix spacing

* chore: fix trailing whitespace and ensure final newline

* Cleanup: remove redundant declarations already defined in header

* Sync docs/ops.md with updated backend operation support

* docs: update ops.md after rebase

* docs: update ops.md - Vulkan supports SSM_CONV and SSM_SCAN
2025-10-20 11:08:32 +03:00
takuya kodama 7062dd8460 llama-context: only warn on pooling_type when user specified (#16674)
The unexpeced pooling_type warning was incorrectly shown when users did not
specify the --pooling-type parameter. In this case, the parameter
defaults to `LLAMA_POOLING_TYPE_UNSPECIFIED (-1)`, and the code
automatically applies the model's default pooling type.

Example of spurious warning:
```
$ llama-embedding -hf ggml-org/bge-m3-Q8_0-GGUF -p "hello"
...
llama_init_from_model: model default pooling_type is [2], but [-1] was specified
...
```

This fix ensures the warning only appears when users explicitly specify
a pooling type that differs from the model's default (e.g., using
--pooling-type mean on a model that expects CLS pooling).
2025-10-20 10:44:21 +03:00
Giuseppe Scrivano 0398752dd4 model : add Granite Hybrid types (#16635)
add Granite 4 models mapping their embedding dimensions to the # of
parameters.

Information taken from https://huggingface.co/ibm-granite/granite-4.0-h-tiny

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2025-10-19 23:54:31 +02:00
Aaron Teo 4f73d0a951 ci : fix binaries release failure for s390x (binaries may not work yet) (#16664)
* devops: initial patch

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* devops: forgot the z15 suffix

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* devops: attempt at impl GGML_CPU_ALL_VARIANTS for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* devops: rm baseline version

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

---------

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-10-19 23:06:39 +02:00
Sigbjørn Skjæret cec5edbcae ci : avoid manual updates of docs/ops.md (#16663) 2025-10-19 14:03:25 +02:00
Aaron Teo fcb235b466 ci: include s390x release binaries (#16648)
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-10-19 18:37:47 +08:00
Aman Gupta 55754bebd5 CODEOWNERS: update for ggml-cuda/mmf (#16660) 2025-10-19 10:37:12 +03:00
Johannes Gäßler ee09828cb0 HIP: fix GPU_TARGETS (#16642) 2025-10-18 14:47:32 +02:00
Jeff Bolz e56abd2098 vulkan: Implement topk_moe fused shader, ported from CUDA (#16641)
This is similar to the CUDA shader from #16130, but doesn't use shared memory
and handles different subgroup sizes.
2025-10-18 12:22:57 +02:00
Aman Gupta 38355c6c8e CUDA: use registers instead of smem in topk-moe (#16647)
Uses the technique used in the vulkan PR #16641. Neat trick!
2025-10-18 11:52:53 +02:00
Shawn Gu 81387858f1 opencl: transposed gemm/gemv moe kernel with mxfp4,f32 (#16602)
* opencl: transposed gemm/gemv moe kernel with mxfp4,f32

* add restore kernel for moe transpose

* fix trailing whitespaces

* resolve compilation warnings
2025-10-17 17:55:32 -07:00
Johannes Gäßler 66b0dbcb2d llama-model: fix insonsistent ctxs <-> bufs order (#16581) 2025-10-17 17:41:09 +02:00
Radoslav Gerganov 41386cf365 rpc : report actual free memory (#16616)
* rpc : report actual free memory

Start reporting the free memory on every device instead of using
fixed values. Now llama-cli users can get a nice memory breakdown
when using RPC devices.

* drop --mem in rpc-server
2025-10-17 18:02:52 +03:00
Giuseppe Scrivano 3d4e86bbeb vulkan: Add State Space Model (SSM) Operations Support (#16463)
* vulkan: implement SSM scan operation

Add State Space Model scan operation to the Vulkan backend.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>

* vulkan: implement SSM conv operation

Add State Space Model conv operation to the Vulkan backend.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>

---------

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2025-10-17 14:23:47 +02:00
muggle-stack 342c728d03 ggml : fix SpaceMit IME array out-of-bounds in task assignment (#16629)
Fix incorrect task-to-batch index calculation in the quantization phase.

The bug caused out-of-bounds access to qnbitgemm_args array when
compute_idx exceeded per_gemm_block_count_m, leading to invalid
pointer dereferences and SIGBUS errors.

Correctly map tasks to batches by dividing compute_idx by
per_gemm_block_count_m instead of block_size_m.

Example:
  batch_feature=1, gemm_m=30, block_size_m=4
  per_gemm_block_count_m = 8, task_count = 8

  Old: gemm_idx = 4/4 = 1 (out of bounds  New: gemm_idx = 4/8 = 0 (correct)

Tested on SpaceMit K1 RISC-V64 with qwen2.5:0.5b model.

Co-authored-by: muggle <mingjun.rong@spacemit.com>
2025-10-17 13:01:23 +03:00
Pascal ababae7e1e webui: reorganize settings layout (#16607)
* webui: reorganize settings layout

* chore: update webui build output

* fix: remove unused variable

* chore: update webui build output
2025-10-17 10:35:03 +02:00
Jeff Bolz b19491599d vulkan: fix debug build (add_rms_len/data not found) (#16624) 2025-10-17 09:31:04 +02:00
Ilia Ilmer 9ad4f1931e metal : add CONV_TRANSPOSE_2D (#16542)
* initial: headers and metal-device.cpp updates

* adding conv_transpose_2d

* fix type

* fix type: int32->int64

* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add checks for src[0] and src[1]; add type checks

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add more tests, add optimization to threading

* add dynamic memory allocation in metal

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-17 09:33:58 +03:00
Olivier Chafik 79967ec596 grammar : use int64_t to avoid int overflows in int schema to grammar conversion logic (#16626) 2025-10-17 08:59:31 +03:00
GittyBurstein ceff6bb253 SYCL SET operator optimized for F32 tensors (#16350)
* SYCL/SET: implement operator + wire-up; docs/ops updates; element_wise & ggml-sycl changes

* sycl(SET): re-apply post-rebase; revert manual docs/ops.md; style cleanups

* move SET op to standalone file, GPU-only implementation

* Update SYCL SET operator for F32

* ci: fix editorconfig issues (LF endings, trailing spaces, final newline)

* fixed ggml-sycl.cpp

---------

Co-authored-by: Gitty Burstein <gitty@example.com>
2025-10-17 10:36:40 +08:00
Xuan-Son Nguyen 1bb4f43380 mtmd : support home-cooked Mistral Small Omni (#14928) 2025-10-16 19:00:31 +02:00
Pascal 683fa6ba4e fix: added a normalization step for MathJax-style \[\] and \(\) delimiters (#16599)
* fix: added a normalization step for MathJax-style \[\] and \(\) delimiters

So inline and block equations are converted before KaTeX rendering,
enabling proper display of model-generated LaTeX in the WebUI

* chore: update webui build output
2025-10-16 16:28:41 +02:00
GittyBurstein b22572e97d sycl : add ARANGE operator (#16362)
* SYCL: update element-wise ops and presets

* clean arange

* Re-trigger CI

---------

Co-authored-by: Gitty Burstein <gitty@example.com>
2025-10-16 15:26:21 +02:00
Chenguang Li 7a50cf388a CANN: format code using .clang-format (#15863)
This commit applies .clang-format rules to all source files under the
ggml-cann directory to ensure consistent coding style and readability.
The .clang-format option `SortIncludes: false` has been set to disable
automatic reordering of include directives.
No functional changes are introduced.

Co-authored-by: hipudding <huafengchun@gmail.com>
2025-10-16 16:41:11 +08:00
takasurazeem 6f5d924637 common : Update the docs on -t --threads (#16236)
* Update the docs on -t --threads

* Revert "Update the docs on -t --threads"

This reverts commit eba97345e2.

* docs: clarify -t/--threads parameter uses CPU threads and defaults to all available cores

* Update arg.cpp
2025-10-16 08:11:33 +03:00
takuya kodama adc9b60f19 ggml-cpu: replace putenv with setenv for const-correctness (#16573)
## Why it failed

When compiling with strict compiler flags (-Wwrite-strings -Werror=discarded-qualifiers),
the build fails with the following error:

```
cmake \
  -S . \
  -B ../llama.cpp.build \
  --preset=x64-linux-gcc-debug \
  -DCMAKE_INSTALL_PREFIX=/tmp/local \
  -DCMAKE_C_FLAGS="-Wwrite-strings -Werror=discarded-qualifiers" && \
cmake --build ../llama.cpp.build/
...
/home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c: In function ‘ggml_cpu_init’:
/home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c:3572:24: error: passing argument 1 of ‘putenv’ discards ‘const’ qualifier from pointer target type [-Werror=discarded-qualifiers]
 3572 |                 putenv("KMP_BLOCKTIME=200"); // 200ms
      |                        ^~~~~~~~~~~~~~~~~~~
In file included from /home/otegami/work/cpp/llama.cpp/ggml/src/./ggml-impl.h:10,
                 from /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h:6,
                 from /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h:3,
                 from /home/otegami/work/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c:6:
/usr/include/stdlib.h:786:26: note: expected ‘char *’ but argument is of type ‘const char *’
  786 | extern int putenv (char *__string) __THROW __nonnull ((1));
      |                    ~~~~~~^~~~~~~~
cc1: some warnings being treated as errors
ninja: build stopped: subcommand failed.
```

The issue is that putenv() expects a non-const char * but receives a string literal (const char *).

## How to fix

This PR replaces putenv("KMP_BLOCKTIME=200") with setenv("KMP_BLOCKTIME", "200", 0).

Benefits of setenv():
- Accepts const char * parameters (no qualifier warnings)
- Makes copies of the strings (safer memory handling)
- The third parameter (0) ensures we don't overwrite if already set
2025-10-16 08:10:32 +03:00
yael-works ee50ee1ead SYCL: Add GGML_OP_MEAN operator support (#16009)
* SYCL: Add GGML_OP_MEAN operator support

* SYCL: Fix formatting for GGML_OP_MEAN case

* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-10-16 12:21:28 +08:00
Aleksei Nikiforov 7adc79c032 gguf-py : add support for endian conversion of BF16 data (#16594)
BF16 requires special handling in this script
while it's a 2-bytes data, but view is 1-byte by default.
Switch to correct view before attempting byteswapping.

With this change correctly byteswapping models like
Meta-Llama-3-8B-Instruct-bf16-GGUF
should be possible.
2025-10-15 22:43:08 +02:00
safranowith 466c1911ab cpu : add FLOOR, CEIL, ROUND and TRUNC unary operators (#16083)
* CPU: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators

- Added the operators to unary op enum
- Implemented API functions
- Implemented forward and unary-op logic in CPU backend
- Updated ggml_get_n_tasks
- Updated operators names array and static_assert
- Updated docs and enabled automatic tests

* docs: add documentation for ggml_trunc and ggml_trunc_inplace in ggml.h

* chore: remove trailing whitespace from ggml.h

* Remove unresolved merge markers

* Apply review suggestions: cleanup formatting, enum order and leftover artifacts

* Regenerate ops.md using create_ops_docs.py
2025-10-15 21:24:51 +02:00
lhez 0cb7a0683b opencl: add q8_0 mm support (#16469)
* opencl: add mm_q8_0_f32

* opencl: fix data loading for incomplete tile

* opencl: use q8_0 mm for larger matrix

* opencl: add some tests to cover the path
2025-10-15 10:51:04 -07:00
lhez d93f8439b0 opencl: fix FA for f32 (#16584) 2025-10-15 10:48:28 -07:00
Aleksander Grygier f9fb33f263 Add server-driven parameter defaults and syncing (#16515) 2025-10-15 16:22:20 +02:00
Sam/Samuel f4ce81c45e metal: optimise GGML_OP_SUM (#16559)
* optimise GGML_OP_SUM

* add non-contiguous tests by permuting the input

* change tests to require full contiguity of OP_SUM

* cuda : add check GGML_OP_SUM

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-15 17:05:56 +03:00
Georgi Gerganov 17304cbcc1 server : fix img token logs (#16595) 2025-10-15 16:53:12 +03:00
Xuan-Son Nguyen 3e3cb19f64 llama-quant: add support for mmproj (#16592)
* llama-quant: add support for mmproj

* Update src/llama.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* check prefix instead

* small fix

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-15 14:48:08 +02:00
Julius Tischbein 5acd455460 CUDA: Changing the CUDA scheduling strategy to spin (#16585)
* CUDA set scheduling strategy to spinning for cc121

* Using prop.major and prop.minor, include HIP and MUSA

* Exclude HIP and MUSA

* Remove trailing whitespace

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Remove empty line

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-10-15 14:54:15 +03:00
Georgi Gerganov 554fd578a5 server : fix mtmd checkpoints (#16591) 2025-10-15 11:51:27 +02:00
Georgi Gerganov fa882fd2b1 metal : avoid using Metal's gpuAddress property (#16576)
* metal : avoid using Metal's gpuAddress property

* metal : fix rope kernels buffer check
2025-10-14 20:33:05 +03:00
SavicStefan ffa059034c vulkan: Add ACC_TYPE_VEC2 implementation (#16203)
Signed-off-by: Stefan Savic <stefan.savic@huawei.com>
Co-authored-by: Stefan Savic <stefan.savic@huawei.com>
2025-10-14 19:18:05 +02:00
Aman Gupta 120bf7046d CUDA + openCL: fix bug in accessing rms_norm->src while doing fusion (#16577) 2025-10-14 07:48:08 -07:00
Jeff Bolz 4258e0cfe7 vulkan: Support FA with K/V in F32 (#16543) 2025-10-14 15:53:37 +02:00
Jeff Bolz 7ea15bb64c vulkan: Improve build time for MSVC (#16545)
Enable CMP0147 so custom build steps (invoking vulkan-shader-gen) are run in parallel.

Enable /MP so source files are compiled in parallel.
2025-10-14 14:51:36 +02:00
Johannes Gäßler 9c7185dd28 CUDA: enable FA for FP32 KV cache (#16546) 2025-10-14 14:22:47 +02:00
Aman Gupta 1ee9d0b415 CUDA: use fastdiv + ggml_cuda_mad for mmvf (#16557)
* CUDA: use fastdiv + ggml_cuda_mad for mmvf

* use bf16 directly + fix formatting

* Add exception for HIP code
2025-10-14 13:16:21 +02:00
Aman Gupta 48e2fa9fb7 CUDA: add fp kernel for larger batch size MoE (#16512)
* CUDA: kernel for larger batch sizes for MoE

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* fixup

* tests

* Move mmq_ids_helper to mmid

* cleanup

* Remove redundant checks
2025-10-14 13:15:15 +02:00
Anav Prasad 5b6913c47b cuda : remove legacy copy-op pointer indirection code (#16485)
* remove legacy copy-op pointer indirection code

* further removal of copy-op indirection code

* renamed check_node_graph_compatibility_and_refresh_copy_ops function
2025-10-14 11:53:49 +02:00
Georgi Gerganov bc07349a7f server : dynamic token limit for prompt cache (#16560)
* server : dynamic token limit for prompt cache

* cont : print estimated token limit
2025-10-14 08:48:50 +03:00
Georgi Gerganov e60f241eac metal : FA support F32 K and V and head size = 32 (#16531)
* metal : FA support F32 K and V and head size = 32

* graph : remove obsolete comment [no ci]
2025-10-13 23:07:57 +03:00
Georgi Gerganov e38b7c6e9e graph : support cacheless embeddings with FA and iSWA (#16528)
* graph : support cacheless embeddings with FA and iSWA

* cont : deduplicate mask creation

* cont : fix name
2025-10-13 22:42:37 +03:00
127 changed files with 8290 additions and 3560 deletions
+2
View File
@@ -134,6 +134,8 @@ jobs:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 's390x-z15' # z15 because our CI runners are on z15
os: ubuntu-22.04-s390x
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
# - build: 'arm64'
# os: ubuntu-22.04-arm
+2
View File
@@ -3,10 +3,12 @@ name: Update Operations Documentation
on:
push:
paths:
- 'docs/ops.md'
- 'docs/ops/**'
- 'scripts/create_ops_docs.py'
pull_request:
paths:
- 'docs/ops.md'
- 'docs/ops/**'
- 'scripts/create_ops_docs.py'
+1 -1
View File
@@ -55,7 +55,7 @@
/ggml/src/ggml-cuda/common.cuh @slaren
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
+1
View File
@@ -187,6 +187,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
</details>
+1 -1
View File
@@ -75,7 +75,7 @@ if [ ! -z ${GG_BUILD_ROCM} ]; then
exit 1
fi
CMAKE_EXTRA="${CMAKE_EXTRA} -DAMDGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
CMAKE_EXTRA="${CMAKE_EXTRA} -DGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
fi
if [ ! -z ${GG_BUILD_SYCL} ]; then
+1 -1
View File
@@ -1760,7 +1760,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
add_opt(common_arg(
{"-t", "--threads"}, "N",
string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads),
string_format("number of CPU threads to use during generation (default: %d)", params.cpuparams.n_threads),
[](common_params & params, int value) {
params.cpuparams.n_threads = value;
if (params.cpuparams.n_threads <= 0) {
+12 -12
View File
@@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int>::min();
auto has_max = max_value != std::numeric_limits<int>::max();
static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int64_t>::min();
auto has_max = max_value != std::numeric_limits<int64_t>::max();
auto digit_range = [&](char from, char to) {
out << "[";
@@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
if (has_min) {
if (min_value < 0) {
out << "\"-\" (";
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
_build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
out << ") | [0] | [1-9] ";
more_digits(0, decimals_left - 1);
} else if (min_value == 0) {
@@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
}
digit_range(c, c);
out << " (";
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
_build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
out << ")";
if (c < '9') {
out << " | ";
@@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
} else {
out << "\"-\" (";
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
_build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
out << ")";
}
return;
@@ -925,17 +925,17 @@ public:
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
int min_value = std::numeric_limits<int>::min();
int max_value = std::numeric_limits<int>::max();
int64_t min_value = std::numeric_limits<int64_t>::min();
int64_t max_value = std::numeric_limits<int64_t>::max();
if (schema.contains("minimum")) {
min_value = schema["minimum"].get<int>();
min_value = schema["minimum"].get<int64_t>();
} else if (schema.contains("exclusiveMinimum")) {
min_value = schema["exclusiveMinimum"].get<int>() + 1;
min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
}
if (schema.contains("maximum")) {
max_value = schema["maximum"].get<int>();
max_value = schema["maximum"].get<int64_t>();
} else if (schema.contains("exclusiveMaximum")) {
max_value = schema["exclusiveMaximum"].get<int>() - 1;
max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
}
std::stringstream out;
out << "(";
+6 -2
View File
@@ -22,6 +22,7 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
@@ -41,6 +42,7 @@ Legend:
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@@ -82,6 +84,7 @@ Legend:
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -97,8 +100,8 @@ Legend:
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
@@ -108,5 +111,6 @@ Legend:
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+16
View File
@@ -59,6 +59,14 @@
"CPU","EXP","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
"CPU","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","CPU"
"CPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
"CPU","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
@@ -119,6 +127,14 @@
"CPU","EXP","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
"CPU","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","CPU"
"CPU","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
"CPU","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=0","support","1","yes","CPU"
"CPU","REGLU","type=f16,ne_a=[5,7,11,13],v=0,swapped=0","support","1","yes","CPU"
"CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=1","support","1","yes","CPU"
Can't render this file because it is too large.
+16
View File
@@ -31,6 +31,14 @@
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
"SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@@ -95,6 +103,14 @@
"SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
"SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
Can't render this file because it is too large.
+21 -21
View File
@@ -3263,27 +3263,27 @@
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=0","support","1","yes","Vulkan"
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=1","support","1","yes","Vulkan"
"Vulkan0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","Vulkan"
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","Vulkan"
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
Can't render this file because it is too large.
+1 -2
View File
@@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
+44
View File
@@ -577,6 +577,10 @@ extern "C" {
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_XIELU,
GGML_UNARY_OP_FLOOR,
GGML_UNARY_OP_CEIL,
GGML_UNARY_OP_ROUND,
GGML_UNARY_OP_TRUNC,
GGML_UNARY_OP_COUNT,
};
@@ -1151,6 +1155,46 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_floor(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_floor_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_ceil(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_ceil_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_round(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_round_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
/**
* Truncates the fractional part of each element in the tensor (towards zero).
* For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
* Similar to std::trunc in C/C++.
*/
GGML_API struct ggml_tensor * ggml_trunc(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_trunc_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// xIELU activation function
// x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
// where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
+12
View File
@@ -307,6 +307,10 @@ function(ggml_add_cpu_backend_variant tag_name)
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
endif()
ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -371,6 +375,14 @@ if (GGML_CPU_ALL_VARIANTS)
else()
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif()
+22
View File
@@ -598,6 +598,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
}
// free the extra space at the end if the new tensor is smaller
static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
GGML_ASSERT(parent_size >= node_size);
if (parent_size > node_size) {
struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
struct buffer_address p_addr = p_hn->addr;
p_addr.offset += node_size;
size_t extra_size = parent_size - node_size;
AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
}
}
static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
GGML_ASSERT(buffer_id >= 0);
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
@@ -643,6 +663,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
hn->addr = p_hn->addr;
p_hn->allocated = false; // avoid freeing the parent
view_src_hn->allocated = false;
ggml_gallocr_free_extra_space(galloc, node, view_src);
return;
}
} else {
@@ -650,6 +671,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
hn->buffer_id = p_hn->buffer_id;
hn->addr = p_hn->addr;
p_hn->allocated = false; // avoid freeing the parent
ggml_gallocr_free_extra_space(galloc, node, parent);
return;
}
}
+46 -43
View File
@@ -51,28 +51,31 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
return ACL_DT_UNDEFINED;
}
aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
size_t* nb, int64_t dims, aclFormat format,
size_t offset) {
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne,
size_t * nb,
int64_t dims,
aclFormat format,
size_t offset) {
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
// added.
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
if (ne == nullptr) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
acl_ne[i] = tensor->ne[i];
acl_ne[i] = tensor->ne[i];
// The step size of acl is in elements.
acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor);
}
} else {
// With bcast
for (int i = 0; i < dims; i++) {
acl_ne[i] = ne[i];
acl_ne[i] = ne[i];
acl_stride[i] = nb[i] / ggml_element_size(tensor);
}
}
int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims);
int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims);
int64_t acl_storage_len = 1;
for (int i = 0; i < final_dims; i++) {
acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];
@@ -84,15 +87,13 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
std::reverse(acl_ne, acl_ne + final_dims);
std::reverse(acl_stride, acl_stride + final_dims);
aclTensor* acl_tensor = aclCreateTensor(
acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
elem_offset, format, &acl_storage_len, 1,
tensor->data);
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
elem_offset, format, &acl_storage_len, 1, tensor->data);
return acl_tensor;
}
bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) {
return true;
@@ -101,15 +102,16 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
return false;
}
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
const ggml_tensor* src1,
int64_t* bcast_src0_ne,
int64_t* bcast_src1_ne, size_t* bcast_src0_nb,
size_t* bcast_src1_nb) {
int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,
const ggml_tensor * src1,
int64_t * bcast_src0_ne,
int64_t * bcast_src1_ne,
size_t * bcast_src0_nb,
size_t * bcast_src1_nb) {
GGML_ASSERT(ggml_can_repeat(src1, src0));
int bcast_dim_cnt = 0;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
int64_t nr = src0->ne[i] / src1->ne[i];
int64_t nr = src0->ne[i] / src1->ne[i];
bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr;
bcast_src1_ne[bcast_dim_cnt] = src1->ne[i];
bcast_src0_nb[bcast_dim_cnt] = src0->nb[i];
@@ -119,21 +121,26 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
// Need to add an extra dim.
bcast_src0_ne[bcast_dim_cnt] = nr;
bcast_src1_ne[bcast_dim_cnt] = 1;
bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] *
bcast_src0_ne[bcast_dim_cnt - 1];
bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] *
bcast_src1_ne[bcast_dim_cnt - 1];
bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] * bcast_src0_ne[bcast_dim_cnt - 1];
bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] * bcast_src1_ne[bcast_dim_cnt - 1];
bcast_dim_cnt++;
}
}
return bcast_dim_cnt;
}
int64_t ggml_cann_get_mulmat_bcast_shape(
const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) {
int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,
const int64_t * weight_ne,
const int64_t * dst_ne,
const size_t * input_nb,
const size_t * weight_nb,
const size_t * dst_nb,
int64_t * bcast_input_ne,
int64_t * bcast_weight_ne,
int64_t * bcast_dst_ne,
size_t * bcast_input_nb,
size_t * bcast_weight_nb,
size_t * bcast_dst_nb) {
// input and dst shoule in same shape, except first two dims.
GGML_ASSERT(input_ne[2] == dst_ne[2]);
GGML_ASSERT(input_ne[3] == dst_ne[3]);
@@ -148,34 +155,30 @@ int64_t ggml_cann_get_mulmat_bcast_shape(
// Do not use bcast in the first two dimensions because we only support
// the bcast batch dimension. Just copy them.
if (i < 2 || nr == 1) {
bcast_input_ne[bcast_dim_cnt] = input_ne[i];
bcast_input_ne[bcast_dim_cnt] = input_ne[i];
bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i];
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i];
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
bcast_dim_cnt++;
} else {
// Need to add an extra dim.
bcast_input_ne[bcast_dim_cnt] = nr;
bcast_dst_ne[bcast_dim_cnt] = nr;
bcast_input_ne[bcast_dim_cnt] = nr;
bcast_dst_ne[bcast_dim_cnt] = nr;
bcast_weight_ne[bcast_dim_cnt] = 1;
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
bcast_dim_cnt++;
bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] *
bcast_input_ne[bcast_dim_cnt - 1];
bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] *
bcast_dst_ne[bcast_dim_cnt - 1];
bcast_weight_nb[bcast_dim_cnt] =
bcast_weight_nb[bcast_dim_cnt - 1] *
bcast_weight_ne[bcast_dim_cnt - 1];
bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] * bcast_input_ne[bcast_dim_cnt - 1];
bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] * bcast_dst_ne[bcast_dim_cnt - 1];
bcast_weight_nb[bcast_dim_cnt] = bcast_weight_nb[bcast_dim_cnt - 1] * bcast_weight_ne[bcast_dim_cnt - 1];
bcast_dim_cnt++;
}
}
Executable → Regular
+54 -43
View File
@@ -62,10 +62,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
* @return Pointer to the created ACL tensor.
*/
aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr,
size_t* nb = nullptr, int64_t dims = 0,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0);
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
int64_t * ne = nullptr,
size_t * nb = nullptr,
int64_t dims = 0,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0);
/**
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
@@ -87,12 +89,15 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
* @return Pointer to the created ACL tensor.
*/
template<typename TYPE>
aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
TYPE type_size, int64_t* ne, TYPE* nb,
int64_t dims,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0) {
template <typename TYPE>
aclTensor * ggml_cann_create_tensor(void * data_ptr,
aclDataType dtype,
TYPE type_size,
int64_t * ne,
TYPE * nb,
int64_t dims,
aclFormat format = ACL_FORMAT_ND,
size_t offset = 0) {
int64_t tmp_ne[GGML_MAX_DIMS * 2];
int64_t tmp_stride[GGML_MAX_DIMS * 2];
@@ -109,9 +114,8 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
std::reverse(tmp_ne, tmp_ne + dims);
std::reverse(tmp_stride, tmp_stride + dims);
aclTensor* acl_tensor =
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
format, &acl_storage_len, 1, data_ptr);
aclTensor * acl_tensor =
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
return acl_tensor;
}
@@ -132,7 +136,7 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
* to 1. If such a dimension is found, broadcasting is required to align t1
* with t0 for element-wise operations.
*/
bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1);
/**
* @brief Computes broadcast shapes and strides for two ggml_tensors.
@@ -187,19 +191,21 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
* dim1 in a inserted dim, should add nb for dim1,
* and all other nb moves to next in order.
*/
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
size_t* bcast_nb_src0, size_t* bcast_nb_src1);
int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,
const ggml_tensor * src1,
int64_t * bcast_ne_src0,
int64_t * bcast_ne_src1,
size_t * bcast_nb_src0,
size_t * bcast_nb_src1);
// Bcast macro to avoid duplicate code.
#define BCAST_SHAPE(src0, src1) \
int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \
size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \
int64_t bcast_dims = ggml_cann_get_bcast_shape( \
src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \
bcast_##src1##_nb);
#define BCAST_SHAPE(src0, src1) \
int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \
size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \
int64_t bcast_dims = ggml_cann_get_bcast_shape(src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, \
bcast_##src0##_nb, bcast_##src1##_nb);
#define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
@@ -233,26 +239,31 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* sr
* before cast dim.
* @sa ggml_cann_get_bcast_shape
*/
int64_t ggml_cann_get_mulmat_bcast_shape(
const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb);
int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,
const int64_t * weight_ne,
const int64_t * dst_ne,
const size_t * input_nb,
const size_t * weight_nb,
const size_t * dst_nb,
int64_t * bcast_input_ne,
int64_t * bcast_weight_ne,
int64_t * bcast_dst_ne,
size_t * bcast_input_nb,
size_t * bcast_weight_nb,
size_t * bcast_dst_nb);
// Bcast macro to avoid duplicate code.
#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \
int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \
size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \
int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \
input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \
bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne, \
bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \
int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \
size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \
int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \
input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, bcast_##input##_ne, bcast_##weight##_ne, \
bcast_##dst##_ne, bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
#define BCAST_MUL_MAT_PARAM(tensor) \
bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
#define BCAST_MUL_MAT_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
#endif // CANN_ACL_TENSOR_H
Executable → Regular
+1181 -1327
View File
File diff suppressed because it is too large Load Diff
Executable → Regular
+189 -212
View File
@@ -62,7 +62,7 @@
* @param dst The ggml tensor representing the destination, which op is
* GGML_OP_REPEAT and specifies the desired dimensions.
*/
void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies the Leaky ReLU activation function to a tensor using the CANN
@@ -82,7 +82,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the result of the Leaky ReLU
* activation is stored, which op is `GGML_OP_LEAKY_RELU`
*/
void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_leaky_relu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Concatenates multiple tensors along a specified dimension using the
@@ -97,7 +97,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @attention tensorList length should be 2 and the dimension using for concat
* default to 1.
*/
void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_concat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Generates a sequence of evenly spaced values within a specified
@@ -113,7 +113,7 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* `start`, 'stop' and 'step' are in dst->op_params and dst->op is
* `GGML_OP_ARANGE`.
*/
void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_arange(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies a clamp operation to the elements of a ggml tensor using the
@@ -131,7 +131,7 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the clamped values will be stored.
* dst->op is `GGML_OP_CLAMP`, `min` and `max` value is in dst->params.
*/
void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_clamp(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Scales the elements of a ggml tensor by a constant factor using the
@@ -148,7 +148,7 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the scaled values will be stored.
* dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params.
*/
void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_scale(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Sorts the elements of a ggml tensor and returns the indices that
@@ -163,7 +163,7 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the sorted indices will be stored.
* dst->op is `GGML_OP_ARGSORT`.
*/
void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the Layer Normalization for a ggml tensor using the CANN
@@ -185,7 +185,7 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the normalized values will be stored.
* @attention `Var` defaults to dst->ne[0].
*/
void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the Group Normalization for a ggml tensor using the CANN
@@ -209,7 +209,7 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
*
* @attention eps defaults to 1e-6f.
*/
void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the accumulation of tensors using the CANN backend.
@@ -228,7 +228,7 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the accumulated values will be stored.
* `inplace` is in dst->params, and dst->op is `GGML_OP_ACC`.
*/
void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the sum of elements along the last dimension of a ggml tensor
@@ -244,7 +244,7 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
*
* @attention `reduce_dims` defaults to 3, which means the last dimension.
*/
void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the sum of elements in a ggml tensor.
@@ -258,7 +258,7 @@ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
*
*/
void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Upsamples a ggml tensor using nearest neighbor interpolation using
@@ -274,8 +274,7 @@ void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the upsampled values will be stored.
* dst->op is `GGML_OP_UPSCALE`.
*/
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
ggml_tensor* dst);
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Pads a ggml tensor to match the dimensions of the destination tensor
@@ -290,7 +289,7 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
* @param dst The destination tensor, which specifies the target dimensions for
* padding. dst->op is `GGML_OP_PAD`.
*/
void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_pad(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Executes a 2D pooling operation on a ggml tensor using the CANN
@@ -307,7 +306,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor on which the pooling operation is to be
* performed. dst->op is `GGML_OP_POOL_2D`.
*/
void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Duplicates a ggml tensor using the CANN backend.
@@ -326,7 +325,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* different shape and dst is no-contiguous.
* @note: This func need to simplify.
*/
void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_dup(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the Root Mean Square (RMS) normalization of a ggml tensor
@@ -348,7 +347,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the normalized values will be stored.
* dst->op is `GGML_OP_RMS_NORM`.
*/
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_rms_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies a diagonal mask to the tensor with a specified value.
@@ -363,7 +362,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* `GGML_OP_DIAG_MASK`
* @param value The value to use for masking.
*/
void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value);
void ggml_cann_diag_mask(ggml_backend_cann_context & ctx, ggml_tensor * dst, float value);
/**
* @brief Performs an image-to-column transformation on the input tensor.
@@ -378,7 +377,7 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float
* @param dst The destination tensor that stores the result of the operation.
* dst->op is `GGML_OP_IM2COL`.
*/
void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_im2col(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes time step embeddings using sine and cosine functions.
@@ -392,10 +391,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the result of the embedding operation
* will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`.
*/
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * dst);
// @see ggml_cann_dup.
void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the softmax activation with optional masking.
@@ -417,7 +416,7 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the result will be stored. dst->op is
* `GGML_OP_SOFTMAX`.
*/
void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Extracts specific rows from a tensor based on indices.
@@ -429,7 +428,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param ctx The backend CANN context for executing operations.
* @param dst The destination tensor where the extracted rows will be stored.
*/
void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Writes specific rows into a tensor at positions specified by indices.
@@ -441,7 +440,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param ctx The backend CANN context for executing operations.
* @param dst The destination tensor where the specified rows will be updated.
*/
void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Executes matrix multiplication for the given tensor.
@@ -454,7 +453,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor for storing the result of the matrix
* multiplication. dst->op is `GGML_OP_MUL_MAT`.
*/
void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies Rotary Positional Embedding (RoPE) to the input tensor.
@@ -477,7 +476,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @note The function currently does not support cases where the freq_scale is
* not equal 1.
*/
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the index of the maximum value along the specified dimension
@@ -492,7 +491,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the indices of the maximum values will
* be stored. dst->op is `GGML_OP_ARGMAX`.
*/
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Adds two tensors element-wise and stores the result in a destination
@@ -509,8 +508,10 @@ void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param acl_src1 The second source tensor.
* @param acl_dst The destination tensor where the result will be stored.
*/
void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
aclTensor* acl_src1, aclTensor* acl_dst = nullptr);
void aclnn_add(ggml_backend_cann_context & ctx,
aclTensor * acl_src0,
aclTensor * acl_src1,
aclTensor * acl_dst = nullptr);
/**
* @brief Sub two tensors element-wise and stores the result in a destination
@@ -527,8 +528,10 @@ void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
* @param acl_src1 The second source tensor.
* @param acl_dst The destination tensor where the result will be stored.
*/
void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
aclTensor* acl_src1, aclTensor* acl_dst = nullptr);
void aclnn_sub(ggml_backend_cann_context & ctx,
aclTensor * acl_src0,
aclTensor * acl_src1,
aclTensor * acl_dst = nullptr);
/**
* @brief Performs element-wise multiplication of two tensors and stores the
@@ -546,8 +549,10 @@ void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
* @param acl_other The second tensor for element-wise multiplication.
* @param acl_dst The destination tensor where the result will be stored.
*/
void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_other, aclTensor* acl_dst = nullptr);
void aclnn_mul(ggml_backend_cann_context & ctx,
aclTensor * acl_src,
aclTensor * acl_other,
aclTensor * acl_dst = nullptr);
/**
* @brief Matrix division, optionally in-place.
@@ -567,8 +572,10 @@ void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
* @param inplace Flag indicating whether to perform the operation in-place on
* `acl_src`.
*/
void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_other, aclTensor* acl_dst = nullptr);
void aclnn_div(ggml_backend_cann_context & ctx,
aclTensor * acl_src,
aclTensor * acl_other,
aclTensor * acl_dst = nullptr);
/**
* @brief Applies element-wise cosine function to the elements of a tensor.
@@ -584,8 +591,7 @@ void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
* @param acl_dst The destination tensor where the cosine results will be
* stored.
*/
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst);
void aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
/**
* @brief Applies element-wise sine function to the elements of a tensor.
@@ -602,8 +608,7 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
* @param acl_src The source tensor on which the sine function will be applied.
* @param acl_dst The destination tensor where the sine results will be stored.
*/
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst);
void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
/**
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
@@ -621,8 +626,12 @@ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
*/
void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst);
void bcast_shape(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
aclTensor ** acl_src0,
aclTensor ** acl_src1,
aclTensor ** acl_dst);
/**
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
@@ -637,7 +646,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
* @param dst The destination tensor where the transposed convolution result
* will be stored. dst->op is `GGML_OP_CONV_TRANSPOSE_1D`.
*/
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies the ELU (Exponential Linear Unit) activation to a ggml tensor
@@ -662,7 +671,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
* @param dst The destination tensor where the ELU-activated result will be stored.
* dst->op is expected to be `GGML_OP_ELU`.
*/
void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the mean of a ggml tensor element-wise using the CANN backend.
@@ -677,7 +686,7 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the mean result will be stored.
* dst->op is expected to be `GGML_OP_MEAN`.
*/
void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_mean(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies 1D reflect padding to a ggml tensor using the CANN backend.
@@ -692,7 +701,7 @@ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the padded result will be stored.
* dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`.
*/
void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Counts the number of equal elements in two ggml tensors using the CANN backend.
@@ -708,7 +717,7 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the result will be stored.
* dst->op is expected to be `GGML_OP_COUNT_EQUAL`.
*/
void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Applies the Step activation function to a ggml tensor using the CANN backend.
@@ -723,7 +732,7 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the result will be stored.
* dst->op is expected to be `GGML_OP_STEP`.
*/
void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Performs the Flash Attention extended operator using the CANN backend.
@@ -738,59 +747,46 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
* @param dst The destination tensor where the result will be stored.
* dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
*/
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/*
* @brief A generic wrapper for ACL resources with custom deleter support.
*/
using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>;
using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
/**
* @brief Trait structure used to define how to destroy a given ACL resource type.
*
* @tparam T ACL resource type.
*/
template<typename T>
struct acl_resource_traits;
template <typename T> struct acl_resource_traits;
/**
* @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
*/
template<>
struct acl_resource_traits<aclTensor> {
static void destroy(void* p) {
ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
}
template <> struct acl_resource_traits<aclTensor> {
static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
};
/**
* @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
*/
template<>
struct acl_resource_traits<aclIntArray> {
static void destroy(void* p) {
ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
}
template <> struct acl_resource_traits<aclIntArray> {
static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
};
/**
* @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
*/
template<>
struct acl_resource_traits<aclScalar> {
static void destroy(void* p) {
ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
}
template <> struct acl_resource_traits<aclScalar> {
static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
};
/**
* @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
*/
template<>
struct acl_resource_traits<aclTensorList> {
static void destroy(void* p) {
ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
}
template <> struct acl_resource_traits<aclTensorList> {
static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
};
/**
@@ -800,14 +796,8 @@ struct acl_resource_traits<aclTensorList> {
* @param ptr Raw pointer to ACL resource.
* @return any_acl_resource Smart pointer that handles destruction.
*/
template<typename T>
any_acl_resource make_acl_resource(T* ptr) {
return any_acl_resource(
static_cast<void*>(ptr),
[](void* p) {
acl_resource_traits<T>::destroy(p);
}
);
template <typename T> any_acl_resource make_acl_resource(T * ptr) {
return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
}
/**
@@ -817,8 +807,7 @@ any_acl_resource make_acl_resource(T* ptr) {
* @param vec Target vector to hold ACL resources.
* @param args Raw pointers to ACL resources.
*/
template<typename... Args>
void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
(vec.emplace_back(make_acl_resource(args)), ...);
}
@@ -826,39 +815,36 @@ void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
* @brief Task class that wraps the execution of an aclnn function call.
*/
class aclnn_task : public cann_task {
public:
aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr,
uint64_t workspace_size, aclOpExecutor * executor,
aclrtStream stream) :
aclnn_func_(aclnn_func),
workspace_addr_(workspace_addr),
workspace_size_(workspace_size),
executor_(executor),
stream_(stream) {}
virtual void run_task() override {
ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
}
private:
aclnn_func_t aclnn_func_;
void * workspace_addr_;
uint64_t workspace_size_;
aclOpExecutor * executor_;
aclrtStream stream_;
public:
aclnn_task(aclnn_func_t aclnn_func,
void * workspace_addr,
uint64_t workspace_size,
aclOpExecutor * executor,
aclrtStream stream) :
aclnn_func_(aclnn_func),
workspace_addr_(workspace_addr),
workspace_size_(workspace_size),
executor_(executor),
stream_(stream) {}
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
private:
aclnn_func_t aclnn_func_;
void * workspace_addr_;
uint64_t workspace_size_;
aclOpExecutor * executor_;
aclrtStream stream_;
};
/**
* @brief Task class that releases ACL resources after usage.
*/
class release_resource_task : public cann_task {
public:
release_resource_task(std::vector<any_acl_resource>&& resources){
resource_ = std::move(resources);
}
public:
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
virtual void run_task() override {
resource_.clear();
}
private:
virtual void run_task() override { resource_.clear(); }
private:
std::vector<any_acl_resource> resource_;
};
@@ -866,38 +852,40 @@ private:
* @brief Task class for performing asynchronous memory copy operations.
*/
class async_memcpy_task : public cann_task {
public:
async_memcpy_task(void* dst, const void* src, size_t size,
aclrtMemcpyKind kind, aclrtStream stream)
: dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
public:
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
dst_(dst),
src_(src),
size_(size),
kind_(kind),
stream_(stream) {}
virtual void run_task() override {
ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
}
private:
void* dst_;
const void* src_;
size_t size_;
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
private:
void * dst_;
const void * src_;
size_t size_;
aclrtMemcpyKind kind_;
aclrtStream stream_;
aclrtStream stream_;
};
/**
* @brief Task class for performing asynchronous memory set operations.
*/
class async_memset_task : public cann_task {
public:
async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
: buffer_(buffer), size_(size), value_(value), stream_(stream) {}
public:
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
buffer_(buffer),
size_(size),
value_(value),
stream_(stream) {}
virtual void run_task() override {
ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
}
private:
void* buffer_;
size_t size_;
int32_t value_;
aclrtStream stream_;
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
private:
void * buffer_;
size_t size_;
int32_t value_;
aclrtStream stream_;
};
/**
@@ -918,25 +906,24 @@ class async_memset_task : public cann_task {
* same stream are executed in queue order.
*/
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \
uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \
void * workspaceAddr = nullptr; \
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
if (workspaceSize > 0) { \
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
workspaceAddr = workspace_allocator.get(); \
} \
if (CTX.async_mode) { \
auto task = \
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, \
executor, CTX.stream()); \
CTX.task_queue.submit_task(std::move(task)); \
} else { \
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
} \
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \
uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \
void * workspaceAddr = nullptr; \
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
if (workspaceSize > 0) { \
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
workspaceAddr = workspace_allocator.get(); \
} \
if (CTX.async_mode) { \
auto task = \
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
CTX.task_queue.submit_task(std::move(task)); \
} else { \
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
} \
} while (0)
/**
@@ -947,11 +934,10 @@ class async_memset_task : public cann_task {
* @param ctx Backend context which manages task submission and async mode.
* @param args Pointers to ACL resources to be released.
*/
template <typename... Args>
void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
std::vector<any_acl_resource> resources;
register_acl_resources(resources, std::forward<Args>(args)...);
if(ctx.async_mode) {
if (ctx.async_mode) {
auto task = std::make_unique<release_resource_task>(std::move(resources));
ctx.task_queue.submit_task(std::move(task));
}
@@ -966,8 +952,11 @@ void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... arg
* @param len Size of memory to copy (in bytes).
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
*/
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
const void * src, size_t len, aclrtMemcpyKind kind) {
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
void * dst,
const void * src,
size_t len,
aclrtMemcpyKind kind) {
if (ctx.async_mode) {
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
ctx.task_queue.submit_task(std::move(task));
@@ -976,8 +965,11 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
}
}
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
const void * src, size_t len, aclrtMemcpyKind kind) {
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
void * dst,
const void * src,
size_t len,
aclrtMemcpyKind kind) {
if (ctx->async_mode) {
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
ctx->task_queue.submit_task(std::move(task));
@@ -994,8 +986,7 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
* @param size Size of the memory buffer (in bytes).
* @param value Value to set in the buffer.
*/
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
size_t size, int value) {
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
if (ctx.async_mode) {
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
ctx.task_queue.submit_task(std::move(task));
@@ -1029,7 +1020,7 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
* @param dst The destination tensor where the expert-weighted token outputs are stored.
* Expected to be of shape [M, K, N, 1].
*/
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Check whether a tensor is a weight tensor for matrix multiplication.
@@ -1041,20 +1032,14 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
*
* @param tensor Pointer to the target ggml_tensor object (const-qualified).
*/
static bool is_matmul_weight(const ggml_tensor* tensor) {
std::string name = ggml_get_name(tensor);
static const std::unordered_set<std::string> weight_suffixes{
"output.weight",
"attn_q.weight",
"attn_k.weight",
"attn_v.weight",
"attn_output.weight",
"ffn_gate.weight",
"ffn_up.weight",
"ffn_down.weight"
};
static bool is_matmul_weight(const ggml_tensor * tensor) {
std::string name = ggml_get_name(tensor);
static const std::unordered_set<std::string> weight_suffixes{ "output.weight", "attn_q.weight",
"attn_k.weight", "attn_v.weight",
"attn_output.weight", "ffn_gate.weight",
"ffn_up.weight", "ffn_down.weight" };
for (const auto& suffix : weight_suffixes) {
for (const auto & suffix : weight_suffixes) {
if (name.find(suffix) != std::string::npos) {
return true;
}
@@ -1078,14 +1063,13 @@ static bool is_matmul_weight(const ggml_tensor* tensor) {
* @param ctx The CANN backend context used to manage execution and resources.
* @param dst The destination tensor.
*/
template <auto binary_op>
void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1];
template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
aclTensor* acl_src0;
aclTensor* acl_src1;
aclTensor* acl_dst;
aclTensor * acl_src0;
aclTensor * acl_src1;
aclTensor * acl_dst;
// Need bcast
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
@@ -1094,7 +1078,6 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
}
/**
* @brief Applies a unary operation to an input tensor using the CANN backend.
*
@@ -1107,12 +1090,12 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
* @param ctx The CANN backend context for managing resources and execution.
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
*/
template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];
template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
aclTensor* acl_src = ggml_cann_create_tensor(src);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
aclTensor * acl_src = ggml_cann_create_tensor(src);
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
unary_op(ctx, acl_src, acl_dst);
ggml_cann_release_resources(ctx, acl_src, acl_dst);
@@ -1138,9 +1121,9 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
*
* @see GGML_CANN_CALL_OP_UNARY
*/
void ggml_cann_op_unary(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,
ggml_backend_cann_context & ctx,
ggml_tensor * dst);
/**
* @brief Applies a gated (GLU-style) unary operation using the CANN backend.
@@ -1172,9 +1155,9 @@ void ggml_cann_op_unary(
*
* @see GGML_CANN_CALL_OP_UNARY_GATED
*/
void ggml_cann_op_unary_gated(
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
ggml_backend_cann_context& ctx, ggml_tensor* dst);
void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,
ggml_backend_cann_context & ctx,
ggml_tensor * dst);
/**
* @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
@@ -1197,16 +1180,13 @@ void ggml_cann_op_unary_gated(
* @see ggml_cann_op_unary
* @see GGML_CANN_CALL_ACLNN_OP
*/
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context& ctx, \
aclTensor* acl_src, \
aclTensor* acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary(lambda, ctx, dst); \
} \
while (0)
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary(lambda, ctx, dst); \
} while (0)
/**
* @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
@@ -1229,15 +1209,12 @@ void ggml_cann_op_unary_gated(
* @see ggml_cann_op_unary_gated
* @see GGML_CANN_CALL_ACLNN_OP
*/
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context& ctx, \
aclTensor* acl_src, \
aclTensor* acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary_gated(lambda, ctx, dst); \
} \
while (0)
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary_gated(lambda, ctx, dst); \
} while (0)
#endif // CANN_ACLNN_OPS
Executable → Regular
+92 -99
View File
@@ -44,7 +44,7 @@
#include "../include/ggml.h"
#include "../ggml-impl.h"
#define MATRIX_ROW_PADDING 512
#define MATRIX_ROW_PADDING 512
#define GGML_CANN_MAX_STREAMS 8
/**
@@ -56,8 +56,7 @@
* @param line The line number at which the error occurred.
* @param msg The error message.
*/
[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
const char* file, int line, const char* msg);
[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
/**
* @brief Checks the result of a CANN function call and invokes the error
@@ -89,25 +88,24 @@ struct ggml_cann_device_info {
* @brief Information about a single CANN device.
*/
struct cann_device_info {
int cc; /**< Compute capability. */
int cc; /**< Compute capability. */
size_t smpb; /**< Maximum shared memory per block. */
bool vmm; /**< Virtual memory support. */
bool vmm; /**< Virtual memory support. */
size_t vmm_granularity; /**< Granularity of virtual memory. */
size_t total_vram; /**< Total video RAM available on the device. */
};
cann_device_info devices[GGML_CANN_MAX_DEVICES] =
{}; /**< Array of CANN device information. */
cann_device_info devices[GGML_CANN_MAX_DEVICES] = {}; /**< Array of CANN device information. */
};
const ggml_cann_device_info& ggml_cann_info();
const ggml_cann_device_info & ggml_cann_info();
void ggml_cann_set_device(int32_t device);
void ggml_cann_set_device(int32_t device);
int32_t ggml_cann_get_device();
std::optional<std::string> get_env(const std::string& name);
bool parse_bool(const std::string& value);
int parse_integer(const std::string& value);
std::optional<std::string> get_env(const std::string & name);
bool parse_bool(const std::string & value);
int parse_integer(const std::string & value);
/**
* @brief Abstract base class for memory pools used by CANN.
@@ -126,7 +124,7 @@ struct ggml_cann_pool {
* will be stored.
* @return Pointer to the allocated memory block.
*/
virtual void* alloc(size_t size, size_t* actual_size) = 0;
virtual void * alloc(size_t size, size_t * actual_size) = 0;
/**
* @brief Frees a previously allocated memory block.
@@ -136,16 +134,16 @@ struct ggml_cann_pool {
* @note Note that all CANN opertors are running async. Make sure memory is
* still avaiable before this operator finished.
*/
virtual void free(void* ptr, size_t size) = 0;
virtual void free(void * ptr, size_t size) = 0;
};
/**
* @brief RAII wrapper for managing memory allocations from a CANN memory pool.
*/
struct ggml_cann_pool_alloc {
ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
void* ptr = nullptr; /**< Pointer to the allocated memory block. */
size_t actual_size = 0; /**< Actual size of the allocated memory block. */
ggml_cann_pool * pool = nullptr; /**< Pointer to the memory pool. */
void * ptr = nullptr; /**< Pointer to the allocated memory block. */
size_t actual_size = 0; /**< Actual size of the allocated memory block. */
/**
* @brief Default constructor.
@@ -156,16 +154,14 @@ struct ggml_cann_pool_alloc {
* @brief Constructor that initializes the memory pool.
* @param pool Reference to the memory pool.
*/
explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
explicit ggml_cann_pool_alloc(ggml_cann_pool & pool) : pool(&pool) {}
/**
* @brief Constructor that initializes the memory pool and allocates memory.
* @param pool Reference to the memory pool.
* @param size Size of the memory block to allocate.
*/
ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
alloc(size);
}
ggml_cann_pool_alloc(ggml_cann_pool & pool, size_t size) : pool(&pool) { alloc(size); }
/**
* @brief Destructor that frees the allocated memory block.
@@ -181,7 +177,7 @@ struct ggml_cann_pool_alloc {
* @param size Size of the memory block to allocate.
* @return Pointer to the allocated memory block.
*/
void* alloc(size_t size) {
void * alloc(size_t size) {
GGML_ASSERT(pool != nullptr);
GGML_ASSERT(ptr == nullptr);
ptr = pool->alloc(size, &this->actual_size);
@@ -194,7 +190,7 @@ struct ggml_cann_pool_alloc {
* @param size Size of the memory block to allocate.
* @return Pointer to the allocated memory block.
*/
void* alloc(ggml_cann_pool& pool, size_t size) {
void * alloc(ggml_cann_pool & pool, size_t size) {
this->pool = &pool;
return alloc(size);
}
@@ -203,25 +199,25 @@ struct ggml_cann_pool_alloc {
* @brief Gets the pointer to the allocated memory block.
* @return Pointer to the allocated memory block.
*/
void* get() { return ptr; }
void * get() { return ptr; }
// Deleted copy constructor
ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
ggml_cann_pool_alloc(const ggml_cann_pool_alloc &) = delete;
// Deleted move constructor
ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
ggml_cann_pool_alloc(ggml_cann_pool_alloc &&) = delete;
// Deleted copy assignment operator
ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
ggml_cann_pool_alloc & operator=(const ggml_cann_pool_alloc &) = delete;
// Deleted move assignment operator
ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
};
/**
* @brief Function pointer type for ACLNN operator calls.
*/
using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
/**
* @brief Base class for all CANN tasks to be submitted to the task queue.
@@ -229,7 +225,7 @@ using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStrea
* Users should override the run_task() method with actual task logic.
*/
class cann_task {
public:
public:
virtual void run_task() {}
};
@@ -237,16 +233,20 @@ public:
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
*/
class cann_task_queue {
public:
public:
/**
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
*
* @param capacity Queue capacity. Must be a power of 2.
* @param device Target device ID (used for context setting).
*/
explicit cann_task_queue(size_t capacity, int32_t device)
: buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
running_(false), device_(device) {
explicit cann_task_queue(size_t capacity, int32_t device) :
buffer_(capacity),
capacity_(capacity),
head_(0),
tail_(0),
running_(false),
device_(device) {
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
mask_ = capacity_ - 1;
}
@@ -257,7 +257,7 @@ public:
* @param item Unique pointer to the task.
* @return true if the task was successfully enqueued, false if the queue was full.
*/
bool enqueue(std::unique_ptr<cann_task>&& item) {
bool enqueue(std::unique_ptr<cann_task> && item) {
size_t next_tail = (tail_ + 1) & mask_;
if (next_tail == head_) {
@@ -276,17 +276,16 @@ public:
*
* @param task Task to be submitted.
*/
void submit_task(std::unique_ptr<cann_task>&& task) {
while(!enqueue(std::move(task))) {
void submit_task(std::unique_ptr<cann_task> && task) {
while (!enqueue(std::move(task))) {
std::this_thread::yield();
continue;
}
if (!running_) {
running_ = true;
thread_ = std::thread(&cann_task_queue::execute, this);
thread_ = std::thread(&cann_task_queue::execute, this);
}
}
/**
@@ -309,7 +308,7 @@ public:
}
}
private:
private:
/**
* @brief Worker thread function that continuously dequeues and executes tasks.
*/
@@ -317,7 +316,7 @@ private:
ggml_cann_set_device(device_);
while (running_) {
if(head_ == tail_) {
if (head_ == tail_) {
std::this_thread::yield();
continue;
}
@@ -330,24 +329,24 @@ private:
}
std::vector<std::unique_ptr<cann_task>> buffer_;
const size_t capacity_;
size_t mask_;
size_t head_;
size_t tail_;
bool running_;
std::thread thread_;
int32_t device_;
const size_t capacity_;
size_t mask_;
size_t head_;
size_t tail_;
bool running_;
std::thread thread_;
int32_t device_;
};
#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
// dst tensor
void * node_address;
void * node_address;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
// src tensor
void * src_address[GGML_MAX_SRC];
void * src_address[GGML_MAX_SRC];
int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
@@ -376,13 +375,11 @@ struct ggml_cann_graph {
* move existing graphs to the front (most recently used), and clear the cache.
*/
struct ggml_cann_graph_lru_cache {
size_t capacity; /**< Maximum number of graphs in the cache. */
size_t capacity; /**< Maximum number of graphs in the cache. */
std::list<ggml_cann_graph*> cache_list; /**< List storing cached graphs as raw pointers. */
std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */
ggml_cann_graph_lru_cache() {
capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12"));
}
ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); }
/**
* @brief Push a new graph to the front of the cache.
@@ -390,11 +387,11 @@ struct ggml_cann_graph_lru_cache {
* @param new_node Pointer to the new ggml_cann_graph to cache.
* Ownership is transferred to the cache (cache will delete it).
*/
void push(ggml_cann_graph* new_node) {
void push(ggml_cann_graph * new_node) {
if (cache_list.size() >= capacity) {
ggml_cann_graph* old = cache_list.back();
ggml_cann_graph * old = cache_list.back();
cache_list.pop_back();
delete old; // free the old graph
delete old; // free the old graph
}
cache_list.push_front(new_node);
}
@@ -403,7 +400,7 @@ struct ggml_cann_graph_lru_cache {
* @brief Move an existing graph to the front of the cache.
* @param node Pointer to the ggml_cann_graph to move.
*/
void move_to_front(ggml_cann_graph* node) {
void move_to_front(ggml_cann_graph * node) {
cache_list.remove(node);
cache_list.push_front(node);
}
@@ -421,92 +418,89 @@ struct ggml_cann_graph_lru_cache {
/**
* @brief Destructor that clears the cache and frees all cached graphs.
*/
~ggml_cann_graph_lru_cache() {
clear();
}
~ggml_cann_graph_lru_cache() { clear(); }
};
#endif // USE_ACL_GRAPH
struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if(theta_scale_cache != nullptr) {
if (theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
if(sin_cache != nullptr) {
if (sin_cache != nullptr) {
ACL_CHECK(aclrtFree(sin_cache));
}
if(cos_cache != nullptr) {
if (cos_cache != nullptr) {
ACL_CHECK(aclrtFree(cos_cache));
}
}
void* theta_scale_cache = nullptr;
void * theta_scale_cache = nullptr;
int64_t theta_scale_length = 0;
// sin/cos cache, used only to accelerate first layer on each device
void* sin_cache = nullptr;
void* cos_cache = nullptr;
int64_t position_length = 0;
void * sin_cache = nullptr;
void * cos_cache = nullptr;
int64_t position_length = 0;
// Properties to check before reusing the sincos cache
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
};
struct ggml_cann_tensor_cache {
~ggml_cann_tensor_cache() {
if(cache != nullptr) {
if (cache != nullptr) {
ACL_CHECK(aclrtFree(cache));
}
}
void* cache = nullptr;
int64_t size = 0;
void * cache = nullptr;
int64_t size = 0;
};
/**
* @brief Context for managing CANN backend operations.
*/
struct ggml_backend_cann_context {
int32_t device; /**< Device ID. */
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
int32_t device; /**< Device ID. */
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
ggml_cann_graph_lru_cache graph_lru_cache;
bool acl_graph_mode = true;
bool acl_graph_mode = true;
#endif
cann_task_queue task_queue;
bool async_mode;
cann_task_queue task_queue;
bool async_mode;
// Rope Cache
ggml_cann_rope_cache rope_cache;
ggml_cann_rope_cache rope_cache;
// Constant Pool
ggml_cann_tensor_cache rms_norm_one_tensor_cache;
ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */
/**
* @brief Constructor for initializing the context with a given device.
* @param device Device ID.
*/
explicit ggml_backend_cann_context(int device)
: device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
explicit ggml_backend_cann_context(int device) :
device(device),
name("CANN" + std::to_string(device)),
task_queue(1024, device) {
ggml_cann_set_device(device);
description = aclrtGetSocName();
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
#ifdef USE_ACL_GRAPH
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n",
__func__, device,
acl_graph_mode ? "GRAPH" : "EAGER",
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
#endif
}
@@ -549,8 +543,7 @@ struct ggml_backend_cann_context {
aclrtStream stream() { return stream(0); }
// TODO: each stream should have a memory pool.
std::unique_ptr<ggml_cann_pool>
mem_pool; /**< Memory pool for the device. */
std::unique_ptr<ggml_cann_pool> mem_pool; /**< Memory pool for the device. */
/**
* @brief Create a new memory pool for a given device.
@@ -563,7 +556,7 @@ struct ggml_backend_cann_context {
* @brief Get or create the memory pool for the context.
* @return Reference to the memory pool.
*/
ggml_cann_pool& pool() {
ggml_cann_pool & pool() {
if (mem_pool == nullptr) {
mem_pool = new_pool_for_device(device);
}
Executable → Regular
+501 -608
View File
File diff suppressed because it is too large Load Diff
+36 -20
View File
@@ -466,29 +466,45 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
list(APPEND GGML_CPU_SOURCES
ggml-cpu/arch/s390/quants.c)
# TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562")
message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15)
elseif (${S390X_M} MATCHES "3931")
message(STATUS "z16 target")
list(APPEND ARCH_FLAGS -march=z16)
elseif (${S390X_M} MATCHES "9175|9176")
# NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
# binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
message(STATUS "z17 target")
list(APPEND ARCH_FLAGS -march=arch15)
else()
message(STATUS "Unknown target")
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
list(APPEND ARCH_FLAGS -march=native -mtune=native)
# for native compilation
if (GGML_NATIVE)
# check machine level to determine target
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
# TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562")
message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15)
elseif (${S390X_M} MATCHES "3931")
message(STATUS "z16 target")
list(APPEND ARCH_FLAGS -march=z16)
elseif (${S390X_M} MATCHES "9175|9176")
# NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
# binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
message(STATUS "z17 target")
list(APPEND ARCH_FLAGS -march=arch15)
else()
message(STATUS "Unknown target")
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
list(APPEND ARCH_FLAGS -march=native -mtune=native)
endif()
# for cross-compilation
elseif(GGML_CPU_ALL_VARIANTS)
# range through IBM z15 to z17
# NOTE: update when a new hardware level is released
foreach (ZHW RANGE 15 17)
if(DEFINED GGML_INTERNAL_Z${ZHW})
message(STATUS "z${ZHW} cross-compile target")
list(APPEND ARCH_FLAGS -march=z${ZHW})
endif()
endforeach()
endif()
if (GGML_VXE)
if (GGML_VXE OR GGML_INTERNAL_VXE)
message(STATUS "VX/VXE/VXE2 enabled")
list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE)
+10 -2
View File
@@ -2184,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
{
n_tasks = 1;
} break;
@@ -3563,13 +3567,17 @@ void ggml_cpu_init(void) {
#ifdef GGML_USE_OPENMP
//if (!getenv("OMP_WAIT_POLICY")) {
// // set the wait policy to active, so that OpenMP threads don't sleep
// putenv("OMP_WAIT_POLICY=active");
// setenv("OMP_WAIT_POLICY", "active", 0)
//}
if (!getenv("KMP_BLOCKTIME")) {
// set the time to wait before sleeping a thread
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
putenv("KMP_BLOCKTIME=200"); // 200ms
#ifdef _WIN32
_putenv_s("KMP_BLOCKTIME", "200"); // 200ms
#else
setenv("KMP_BLOCKTIME", "200", 0); // 200ms
#endif
}
#endif
}
+16
View File
@@ -8993,6 +8993,22 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_exp(params, dst);
} break;
case GGML_UNARY_OP_FLOOR:
{
ggml_compute_forward_floor(params, dst);
} break;
case GGML_UNARY_OP_CEIL:
{
ggml_compute_forward_ceil(params, dst);
} break;
case GGML_UNARY_OP_ROUND:
{
ggml_compute_forward_round(params, dst);
} break;
case GGML_UNARY_OP_TRUNC:
{
ggml_compute_forward_trunc(params, dst);
} break;
case GGML_UNARY_OP_XIELU:
{
ggml_compute_forward_xielu(params, dst);
+3 -2
View File
@@ -485,8 +485,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
int32_t start = ith * task_per_thread;
int32_t end = std::min((ith + 1) * task_per_thread, task_count);
for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
int32_t gemm_idx = compute_idx / block_size_m;
int32_t m_idx = compute_idx % block_size_m * block_size_m;
int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
int32_t m_idx = block_idx_in_gemm * block_size_m;
const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
+32
View File
@@ -73,6 +73,22 @@ static inline float op_log(float x) {
return logf(x);
}
static inline float op_floor(float x) {
return floorf(x);
}
static inline float op_ceil(float x) {
return ceilf(x);
}
static inline float op_round(float x) {
return roundf(x);
}
static inline float op_trunc(float x) {
return truncf(x);
}
template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
unary_op<op_log>(params, dst);
}
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_floor>(params, dst);
}
void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_ceil>(params, dst);
}
void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_round>(params, dst);
}
void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_trunc>(params, dst);
}
void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
const float alpha_n = ggml_get_op_params_f32(dst, 1);
const float alpha_p = ggml_get_op_params_f32(dst, 2);
+4
View File
@@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
-7
View File
@@ -944,13 +944,6 @@ struct ggml_cuda_graph {
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
bool use_cpy_indirection = false;
std::vector<char *> cpy_dest_ptrs;
char ** dest_ptrs_d;
int dest_ptrs_size = 0;
// Index to allow each cpy kernel to be aware of it's position within the graph
// relative to other cpy nodes.
int graph_cpynode_index = -1;
#endif
};
+55 -163
View File
@@ -8,18 +8,16 @@
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
const int nb12, const int nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
const int64_t i03 = i/(ne00 * ne01 * ne02);
@@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
const int nb12, const int nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
const int nb12, const int nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
cpy_blck(cx + x_offset, cdst + dst_offset);
}
// Copy destination pointers to GPU to be available when pointer indirection is in use
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
CUDA_CHECK(cudaStreamSynchronize(stream));
if (cuda_graph->dest_ptrs_d != nullptr) {
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
}
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
}
// copy destination pointers to GPU
CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
cuda_graph->graph_cpynode_index = 0; // reset index
#else
GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
#endif
}
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0;
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q8_0_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0;
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_0_f32_cuda(
@@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1;
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q4_1_f32_cuda(
@@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0;
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_0_f32_cuda(
@@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_q5_1_f32_cuda(
@@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda(
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data;
char ** dest_ptrs_d = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection_for_this_node);
#endif
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
{
if (src0->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
}
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection_for_this_node);
#endif
}
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
bool disable_indirection = true;
ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
}
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
// Prioritize CUDA graph compatibility over direct memory copy optimization.
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
if (src0->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else {
return nullptr;
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<float, half>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<half, half>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<half, float>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
}
ggml_cuda_cpy(ctx, src0, dst);
}
+1 -5
View File
@@ -2,10 +2,6 @@
#define CUDA_CPY_BLOCK_SIZE 64
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
+2 -7
View File
@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
const int nwarps = nthreads / WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = false;
constexpr bool need_f16_V = false;
const bool need_f16_K = type_K == GGML_TYPE_F16;
const bool need_f16_V = type_V == GGML_TYPE_F16;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}
@@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+12 -7
View File
@@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
}
}
#define FATTN_VEC_CASE(D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
return; \
} \
#define FATTN_VEC_CASE(D, type_K, type_V) \
{ \
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
return; \
} \
} \
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
FATTN_VEC_CASE( 64, type_K, type_V) \
@@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS
switch (K->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
break;
case GGML_TYPE_Q4_1:
@@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// If Turing tensor cores available, use them:
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
return BEST_FATTN_KERNEL_VEC;
}
@@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
if (Q->ne[1] == 1) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC;
+14 -33
View File
@@ -273,6 +273,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
// Temporary performance fix:
// Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
// TODO: Check for future drivers the default scheduling strategy and
// remove this call again when cudaDeviceScheduleSpin is default.
if (prop.major == 12 && prop.minor == 1) {
CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
}
#endif // defined(GGML_USE_HIP)
}
@@ -2633,11 +2642,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
}
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@@ -2688,33 +2696,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
if (node->op == GGML_OP_CPY) {
// Store the pointers which are updated for each token, such that these can be sent
// to the device and accessed using indirection from CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
}
}
if (!use_cuda_graph) {
break;
}
}
if (use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = true;
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
}
return use_cuda_graph;
}
@@ -2733,7 +2719,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_CPY &&
node->op != GGML_OP_VIEW) {
return false;
}
@@ -2754,7 +2739,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_CPY &&
node->op != GGML_OP_VIEW
) {
return false;
@@ -2901,7 +2885,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false;
}
@@ -3120,7 +3104,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
@@ -3147,10 +3131,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
if (!use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = false;
}
#else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
@@ -3645,9 +3625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_ACC:
return true;
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
return op->src[0]->ne[0] <= 1024;
+40 -6
View File
@@ -1,5 +1,7 @@
#include "ggml.h"
#include "mmf.cuh"
#include "mmid.cuh"
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
mmf_ids_data ids_info{};
mmf_ids_data * ids_info_ptr = nullptr;
ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_dst = ids ? ne1 : ne2;
@@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
nchannels_y = ids->ne[0];
}
if (ids && ncols_dst > 16) {
const int64_t n_expert_used = ids->ne[0];
const int64_t n_experts = ne02;
const int64_t n_tokens = ne12;
const int64_t ne_get_rows = n_tokens * n_expert_used;
ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
const int si1 = static_cast<int>(ids_s1);
const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
GGML_ASSERT(sis1 > 0);
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
CUDA_CHECK(cudaGetLastError());
ids_info.ids_src_compact = ids_src_compact_dev.get();
ids_info.ids_dst_compact = ids_dst_compact_dev.get();
ids_info.expert_bounds_dev = expert_bounds_dev.get();
ids_info.n_experts = static_cast<int>(n_experts);
ids_info.sis1 = sis1;
ids_info_ptr = &ids_info;
}
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
@@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
@@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
@@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
}
if (mul_mat_id) {
if (type == GGML_TYPE_F32 && src1_ncols > 32) {
if (src0_ne[1] <= 1024 && src1_ncols > 512) {
return false;
}
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
} else if(src0_ne[1] > 1024 && src1_ncols > 128) {
return false;
}
} else {
+313 -31
View File
@@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
const int32_t * ids_dst_compact = nullptr;
const int32_t * expert_bounds_dev = nullptr;
int n_experts = 0;
int sis1 = 0;
};
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
@@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
//This kernel is for larger batch sizes of mul_mat_id
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
const T * __restrict__ x, const float * __restrict__ y,
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int expert_idx = blockIdx.y;
const int expert_start = expert_bounds[expert_idx];
const int expert_end = expert_bounds[expert_idx + 1];
const int ncols_expert = expert_end - expert_start;
const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
const int tile_idx = blockIdx.z;
if (tile_idx >= tiles_for_expert) {
return;
}
const int col_base = tile_idx * cols_per_block;
GGML_UNUSED(channel_ratio);
const int channel_x = expert_idx;
const int sample_dst = 0;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
y += int64_t(sample_y) *stride_sample_y;
dst += int64_t(sample_dst)*stride_sample_dst;
const int32_t * ids_src_expert = ids_src_compact + expert_start;
const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
extern __shared__ char data_mmv[];
char * compute_base = data_mmv;
//const float2 * y2 = (const float2 *) y;
tile_C C[ntA][ntB];
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
if constexpr (std::is_same_v<T, float>) {
float vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float val = 0.0f;
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
val = y[channel*stride_channel_y + token*stride_col_y + col];
}
}
vals[j0] = val;
}
};
gather_tile(0, vals_buf[0]);
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
float2 vals_buf[2][tile_B::I];
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + tile_idx_local*tile_B::I;
const int global_j = col_base + j;
float2 tmp = make_float2(0.0f, 0.0f);
if (j < cols_per_block && global_j < ncols_expert) {
const int src_entry = ids_src_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
const int token = (int) qrm.x;
const int channel = (int) qrm.y;
if (token < ncols_dst_total) {
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
}
}
vals[j0] = tmp;
}
};
if (ntB > 0) {
gather_tile(0, vals_buf[0]);
}
int curr_buf = 0;
int next_buf = 1;
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const float2 tmp = vals_buf[curr_buf][j0];
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
if (itB + 1 < ntB) {
gather_tile(itB + 1, vals_buf[next_buf]);
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
if (itB + 1 < ntB) {
curr_buf ^= 1;
next_buf ^= 1;
}
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
const int global_j = col_base + j;
if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
const int dst_entry = ids_dst_expert[global_j];
const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
@@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids(
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
if (ids) {
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
const mmf_ids_data * ids_data) {
const bool has_ids_data = ids_data && ids_data->ids_src_compact;
// Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
// we prefer the normal mul_mat_f path with has_ids=true.
if (has_ids_data && ncols_dst > 16) {
const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
if (max_tiles == 0) {
return;
}
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
sis1_fd, nch_fd);
} else if (ids) {
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
@@ -258,7 +532,7 @@ void mul_mat_f_cuda(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
cudaStream_t stream, const mmf_ids_data * ids_data) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
@@ -290,7 +564,7 @@ void mul_mat_f_cuda(
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
@@ -300,49 +574,57 @@ void mul_mat_f_cuda(
mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
default: {
GGML_ABORT("fatal error");
@@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
cudaStream_t stream, const mmf_ids_data * ids_data) {
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
@@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
default: {
GGML_ABORT("fatal error");
@@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream);
cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
+164
View File
@@ -0,0 +1,164 @@
#include "common.cuh"
#include "mmid.cuh"
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mm_ids_helper_store {
uint32_t data;
__device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mm_ids_helper[];
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mm_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mm_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < static_cast<int>(gridDim.x) - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
void ggml_cuda_launch_mm_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
switch (n_expert_used) {
case 2:
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 4:
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 6:
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 8:
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 16:
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
case 32:
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
default:
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
break;
}
}
+5
View File
@@ -0,0 +1,5 @@
#pragma once
void ggml_cuda_launch_mm_ids_helper(
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
+3 -166
View File
@@ -1,141 +1,6 @@
#include "mmq.cuh"
#include "quantize.cuh"
#include <vector>
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mmq_ids_helper_store {
uint32_t data;
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mmq_ids_helper[];
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mmq_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mmq_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < static_cast<int>(gridDim.x) - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
#include "mmid.cuh"
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
@@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;
switch (n_expert_used) {
case 2:
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 4:
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 6:
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 8:
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 16:
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 32:
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
default:
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
}
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
CUDA_CHECK(cudaGetLastError());
}
+44 -28
View File
@@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int row = blockIdx.x;
const int channel_dst = blockIdx.y;
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst;
const int tid = threadIdx.x;
@@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x*tmpy.x;
sumf[j] += tmpx.y*tmpy.y;
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
}
}
} else if constexpr (std::is_same_v<T, half>) {
@@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x * tmpy.x;
sumf[j] += tmpx.y * tmpy.y;
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
}
}
} else {
@@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
//TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
const int * x2 = (const int *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
}
}
#else
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const nv_bfloat162 tmpx = x2[col2];
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
}
}
#endif
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
@@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
case 32: {
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
+23 -19
View File
@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float wt_sum = 0.f;
extern __shared__ float data_topk_shared[];
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
float output_weights[experts_per_thread];
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
output_weights[k / WARP_SIZE] = max_val;
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
wt_shared_ptr[k] = max_val;
ids[k] = max_expert;
ids[k] = max_expert;
if constexpr (with_norm) {
wt_sum += max_val;
}
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
const float inv_sum = 1.0f / wt_sum;
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
output_weights[i] *= inv_sum;
}
}
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
weights[i] = wt_shared_ptr[i];
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
}
}
}
@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
cudaStream_t stream = ctx.stream();
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 2:
topk_moe_cuda<2, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 4:
topk_moe_cuda<4, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 8:
topk_moe_cuda<8, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 16:
topk_moe_cuda<16, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 32:
topk_moe_cuda<32, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 64:
topk_moe_cuda<64, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 128:
topk_moe_cuda<128, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 256:
topk_moe_cuda<256, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 512:
topk_moe_cuda<512, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
default:
GGML_ASSERT(false && "fatal error");
+4 -2
View File
@@ -28,8 +28,10 @@ if (CXX_IS_HIPCC)
" Prefer setting the HIP compiler directly. See README for details.")
endif()
else()
# Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
# Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
endif()
cmake_minimum_required(VERSION 3.21)
+11 -2
View File
@@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
const struct ggml_tensor * node = cgraph->nodes[node_idx];
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
return 0;
}
return cgraph->use_counts[hash_pos];
}
// return true if the node's results are only used by N other nodes
// and can be fused into their calculations.
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
const struct ggml_tensor * node = cgraph->nodes[node_idx];
// check the use count against how many we're replacing
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
return false;
}
+25
View File
@@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
char base[256];
char name[256];
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_UPSCALE);
+1
View File
@@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+22 -11
View File
@@ -7,6 +7,8 @@
#include <Metal/Metal.h>
#include <stdatomic.h>
#ifndef TARGET_OS_VISION
#define TARGET_OS_VISION 0
#endif
@@ -22,6 +24,9 @@
// overload of MTLGPUFamilyMetal3 (not available in some environments)
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
// virtual address for GPU memory allocations
static atomic_uintptr_t g_addr_device = 0x000000400ULL;
#if !GGML_METAL_EMBED_LIBRARY
// Here to assist with NSBundle Path Hack
@interface GGMLMetalClass : NSObject
@@ -648,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SCALE:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
@@ -657,6 +667,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
@@ -693,7 +704,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true;
case GGML_OP_FLASH_ATTN_EXT:
// for new head sizes, add checks here
if (op->src[0]->ne[0] != 40 &&
if (op->src[0]->ne[0] != 32 &&
op->src[0]->ne[0] != 40 &&
op->src[0]->ne[0] != 64 &&
op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 &&
@@ -826,7 +838,7 @@ struct ggml_metal_buffer_wrapper {
};
struct ggml_metal_buffer {
void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
void * all_data;
size_t all_size;
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -964,14 +976,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
if (shared) {
res->all_data = ggml_metal_host_malloc(size_aligned);
res->is_shared = true;
res->owned = true;
} else {
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
res->all_data = (void *) 0x000000400ULL;
// use virtual address from g_addr_device counter
res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
res->is_shared = false;
}
res->all_size = size_aligned;
res->owned = true;
res->device = ggml_metal_device_get_obj(dev);
res->queue = ggml_metal_device_get_queue(dev);
@@ -982,15 +995,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
res->buffers[0].metal = nil;
if (size_aligned > 0) {
if (props_dev->use_shared_buffers &&shared) {
if (props_dev->use_shared_buffers && shared) {
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
length:size_aligned
options:MTLResourceStorageModeShared
deallocator:nil];
} else {
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
}
}
@@ -1138,7 +1149,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
if (buf->is_shared) {
memset((char *)tensor->data + offset, value, size);
memset((char *) tensor->data + offset, value, size);
return;
}
@@ -1167,7 +1178,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
if (buf->is_shared) {
memcpy((char *)tensor->data + offset, data, size);
memcpy((char *) tensor->data + offset, data, size);
return;
}
@@ -1222,7 +1233,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
if (buf->is_shared) {
memcpy(data, (const char *)tensor->data + offset, size);
memcpy(data, (const char *) tensor->data + offset, size);
return;
}
+14
View File
@@ -251,6 +251,7 @@ typedef struct {
int32_t sect_1;
int32_t sect_2;
int32_t sect_3;
bool src2;
} ggml_metal_kargs_rope;
typedef struct {
@@ -513,6 +514,19 @@ typedef struct {
uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;
typedef struct {
int32_t IC;
int32_t IH;
int32_t IW;
int32_t KH;
int32_t KW;
int32_t OC;
int32_t s0;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;
typedef struct {
uint64_t ofs0;
uint64_t ofs1;
+75 -1
View File
@@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
} break;
case GGML_OP_UPSCALE:
{
n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -866,12 +870,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
int nth = 32; // SIMD width
while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, (int) n);
const int nsg = (nth + 31) / 32;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
return 1;
}
@@ -2969,6 +2986,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
/* sect_1 =*/ sect_1,
/* sect_2 =*/ sect_2,
/* sect_3 =*/ sect_3,
/* src2 =*/ op->src[2] != nullptr,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
@@ -3104,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
const int32_t IC = op->src[1]->ne[2];
const int32_t IH = op->src[1]->ne[1];
const int32_t IW = op->src[1]->ne[0];
const int32_t KH = op->src[0]->ne[1];
const int32_t KW = op->src[0]->ne[0];
const int32_t OW = op->ne[0];
const int32_t OH = op->ne[1];
const int32_t OC = op->ne[2];
ggml_metal_kargs_conv_transpose_2d args = {
/*.IC =*/ IC,
/*.IH =*/ IH,
/*.IW =*/ IW,
/*.KH =*/ KH,
/*.KW =*/ KW,
/*.OC =*/ OC,
/*.s0 =*/ s0,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
// Metal requires buffer size to be multiple of 16 bytes
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
return 1;
}
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
+1
View File
@@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
+234 -59
View File
@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (tiitg != 0) {
if (args.np == 0) {
return;
}
float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
const uint nsg = (ntg.x + 31) / 32;
float sumf = 0;
for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}
dst[0] = acc;
sumf = simd_sum(sumf);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
}
template <bool norm>
@@ -3748,7 +3778,7 @@ kernel void kernel_rope_norm(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -3801,7 +3831,7 @@ kernel void kernel_rope_neox(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -3872,7 +3902,7 @@ kernel void kernel_rope_multi(
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -3939,7 +3969,7 @@ kernel void kernel_rope_vision(
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -4149,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
kernel void kernel_upscale_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
@@ -5213,8 +5334,30 @@ kernel void kernel_flash_attn_ext(
half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8
#define FA_TYPES_F32 \
half, half4, simdgroup_half8x8, \
float, float4x4, simdgroup_float8x8, \
float, float4x4, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float2, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
@@ -5227,6 +5370,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
@@ -5239,6 +5383,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
@@ -5250,6 +5395,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
@@ -5261,6 +5407,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
@@ -5272,6 +5419,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
@@ -5283,6 +5431,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
@@ -5818,77 +5967,103 @@ kernel void kernel_flash_attn_ext_vec(
float, float4, \
float4
#define FA_TYPES_F32 \
half4, \
float4, \
float4, \
float, \
float, float4, \
float4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
#undef FA_TYPES
+3
View File
@@ -91,8 +91,11 @@ set(GGML_OPENCL_KERNELS
mul_mv_id_q8_0_f32_flat
mul_mv_id_mxfp4_f32
mul_mv_id_mxfp4_f32_flat
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul
norm
relu
+262 -9
View File
@@ -402,12 +402,14 @@ struct ggml_backend_opencl_context {
cl_program program_conv_2d_f32;
cl_program program_conv_2d_f16_f32;
cl_program program_tsembd;
cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
cl_program program_mul_mv_id_mxfp4_f32;
cl_program program_mul_mv_id_mxfp4_f32_flat;
cl_program program_mul_mm_f32_f32_l4_lm;
cl_program program_mul_mm_f16_f32_l4_lm;
cl_program program_mul_mm_q8_0_f32_l4_lm;
cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -451,7 +453,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mat_f16_f32_tiled;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
cl_kernel kernel_convert_block_q4_0_noshuffle;
@@ -474,12 +476,14 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_conv_2d_f32;
cl_kernel kernel_conv_2d_f16_f32;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
cl_kernel kernel_mul_mv_id_mxfp4_f32;
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
std::vector<ProfilingInfo> profiling_info;
@@ -557,14 +561,14 @@ struct ggml_backend_opencl_context {
fprintf(ftrace, "[\n");
for (const ProfilingInfo & info : profiling_info) {
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_queued/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_submit/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_start/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_end/1000);
}
fclose(ftrace);
@@ -775,6 +779,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -1191,6 +1197,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mm_q8_0_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_q8_0_f32_l4_lm.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl");
#endif
backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err));
GGML_LOG_CONT(".");
}
// mul
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1973,6 +1995,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
GGML_LOG_CONT(".");
}
std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
" -cl-mad-enable "
" -cl-fast-relaxed-math";
// gemv_moe_mxfp4_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemv_moe_mxfp4_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
#endif
backend_ctx->program_gemv_moe_mxfp4_f32 =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
GGML_LOG_CONT(".");
}
// gemm_moe_mxfp4_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_moe_mxfp4_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
#endif
backend_ctx->program_gemm_moe_mxfp4_f32 =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
GGML_LOG_CONT(".");
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
GGML_LOG_CONT("\n");
}
@@ -2686,7 +2744,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
!ggml_are_same_shape(mul->src[0], rms_norm)) {
return false;
}
@@ -3281,6 +3339,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
GGML_UNUSED(backend_ctx);
int ne01 = tensor->ne[1];
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
}
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
@@ -3583,14 +3647,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
int ne02 = tensor->ne[2];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
size_t local_work_size[3] = {64, 2, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
return;
}
#endif
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[3] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
@@ -3606,7 +3695,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
{ extra->q }
};
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
tensor->extra = extra;
return;
@@ -3733,6 +3821,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
int ne02 = tensor->ne[2];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
size_t local_work_size[3] = {64, 2, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
#endif
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -6961,6 +7076,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q8_0: {
if (ne11 < 32) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
int batch_stride_a = ne00*ne01;
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
default:
break;
}
@@ -7497,6 +7650,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
const int ne21 = src2->ne[1];
const cl_ulong nb21 = src2->nb[1];
const cl_ulong nb20 = src2->nb[0];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
@@ -7636,6 +7790,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
break;
}
case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, src0)) {
cl_int status;
size_t local_size[3] = {64, 2, 1};
size_t global_size[3] = {64, 2, 1};
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
int tile_size = 320;
if (ne12 == 1) { // for gemv
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
// create a sub_buffer for src2
cl_buffer_region region;
region.origin = offset2;
region.size = ne20 * ne21 * sizeof(int);
buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
// preprocess router table
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
void * host_src2 = malloc(ne21 * nb21);
CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
int total_experts = nb21 / nb20;
int out_idx = 0;
for (int i_expert = 0; i_expert < ne02; i_expert++) {
for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
for (int j = 0; j < ne21; j++) {
for (int i = 0; i < ne20; i++) {
int expert = ((int *)host_src2)[j * total_experts + i];
if (i_expert == expert) {
((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
out_idx += 4;
}
}
}
}
}
buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(tile_size);
global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
}
// create a sub_buffer for src1
cl_buffer_region region;
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// create image for src1
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
if (ne12 == 1) {
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
} else {
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
}
// launch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
// deallocate sub buffers and images
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
CL_CHECK(clReleaseMemObject(buf_src1_image));
CL_CHECK(clReleaseMemObject(buf_src2));
return;
} // else fallback to generic kernel
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
+42
View File
@@ -147,6 +147,27 @@ kernel void kernel_convert_block_mxfp4(
}
}
kernel void kernel_convert_block_mxfp4_trans(
global struct block_mxfp4 * src0,
__global uint4 * dst_q,
__global uchar * dst_e,
uint ne00,
uint ne01
) {
int i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
global struct block_mxfp4 * b = src0 + src_blk_offset;
dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];
dst_e[dst_blk_offset] = b->e;
}
kernel void kernel_restore_block_mxfp4(
global uchar * src_q,
global half * src_e,
@@ -162,6 +183,27 @@ kernel void kernel_restore_block_mxfp4(
}
}
kernel void kernel_restore_block_mxfp4_trans(
__global uint4 * src_q,
__global uchar * src_e,
global struct block_mxfp4 * dst,
uint ne00,
uint ne01
) {
int i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
global struct block_mxfp4 * b = dst + dst_blk_offset;
((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
b->e = src_e[src_blk_offset];
}
//------------------------------------------------------------------------------
// block_q8_0
//------------------------------------------------------------------------------
@@ -4,6 +4,7 @@
#define ACC_TYPE4 float4
#define DATA_TYPE float
#define DATA_TYPE4 float4
#define MASK_DATA_TYPE half
#define CONVERT_ACC4(x) (x)
#define CONVERT_DATA4(x) (x)
@@ -148,7 +149,7 @@ __kernel void flash_attn_f32(
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
@@ -281,7 +282,7 @@ __kernel void flash_attn_f32_q1(
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
@@ -317,7 +318,7 @@ __kernel void flash_attn_f32_q1(
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
@@ -0,0 +1,162 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_MXFP4 32
#define N_SIMDGROUP 2
#define SIMDGROUP_WIDTH 64
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemm_moe_mxfp4_f32(
__global uint4 * src0_q,
__global uchar * src0_e,
__read_only image1d_buffer_t src1,
__global ushort4 * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int tile_size
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
ushort4 router = src2[i20];
ushort expert_id = router.x;
ushort i11 = router.y;
ushort i1 = router.z;
ushort tile_id = router.w;
if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size
return;
}
uint expert_offset = expert_id * ne00 * ne01 / 32;
uint tile_offset = expert_offset + tile_id * tile_size + i01;
__private float sum = 0.0f; // each thread calculate partial sum of one output
// loop along ne00 in block granularity, skip 4 blocks every iter
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ = src0_q[tile_offset + ib00 * ne01];
// convert 8 fp4 to fp16
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
uint offset = i11 * ne00 / 4 + ib00 * 8;
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
uchar regE = src0_e[tile_offset + ib00 * ne01];
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
// if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
// if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
// if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
// if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 outputs per thread in subgroup 0
if (sgid == 0) {
dst = dst + (offsetd >> 2);
dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
}
}
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_MXFP4 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32(
__global uint4 * src0_q,
__global uchar * src0_e,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne11
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
uint i11 = i20 % ne11;
uint expert_id = src2[i20];
uint expert_offset = expert_id * ne00 * ne01 / 32;
__private float sum = 0.0f; // each thread calculate partial sum of one output
// loop along ne00 in block granularity, skip 4 blocks every iter
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];
uint offset = i11 * ne00 / 4 + ib00 * 8;
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 outputs per thread in subgroup 0
if (sgid == 0) {
dst = dst + (offsetd >> 2);
dst[i01 + i20 * ne01] = sum;
}
}
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (loadc_a + l < ne01) {
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
if (loadc_b + l < ne11) {
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
if (loadc_a + l < ne01) {
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
if (loadc_b + l < ne11) {
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
@@ -0,0 +1,154 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q8_0_f32_l4_lm(
global char4 * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 8;
int iqs = idx % 8;
float d = (float)src0_d[ib];
global char4 * qs = src0_q + ib*8 + iqs;
char4 q = *qs;
float4 v = convert_float4(q)*d;
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
+24 -15
View File
@@ -939,6 +939,7 @@ public:
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return true;
}
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
size_t free, total;
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
ggml_backend_dev_memory(dev, &free, &total);
response.free_mem = free;
response.total_mem = total;
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
return true;
}
rpc_server::~rpc_server() {
for (auto buffer : buffers) {
ggml_backend_buffer_free(buffer);
@@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() {
}
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
sockfd_t sockfd) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
@@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
auto dev_id = request.device;
if (dev_id >= backends.size()) {
rpc_msg_get_device_memory_rsp response;
if (!server.get_device_memory(request, response)) {
return;
}
rpc_msg_get_device_memory_rsp response;
response.free_mem = free_mem[dev_id];
response.total_mem = total_mem[dev_id];
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
}
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
if (n_devices == 0 || devices == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
@@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
total / 1024 / 1024, free / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
@@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
printf("Accepted client connection\n");
fflush(stdout);
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
rpc_serve_client(backends, cache_dir, client_socket->fd);
printf("Client connection closed\n");
fflush(stdout);
}
+152
View File
@@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
}
template<typename T>
static __dpct_inline__ T op_floor(T x) {
return sycl::floor(x);
}
template<typename T>
static __dpct_inline__ T op_ceil(T x) {
return sycl::ceil(x);
}
template<typename T>
static __dpct_inline__ T op_round(T x) {
return sycl::round(x);
}
template<typename T>
static __dpct_inline__ T op_trunc(T x) {
return sycl::trunc(x);
}
template<typename T>
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
}
}
template<typename T>
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_floor(x[i]);
}
}
template<typename T>
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_ceil(x[i]);
}
}
template<typename T>
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_round(x[i]);
}
}
template<typename T>
static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_trunc(x[i]);
}
}
template<typename T>
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -397,6 +445,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
});
}
template<typename T>
static void arange_kernel(T * dst, const int k, T start, T step,
const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = start + static_cast<T>(i) * step;
}
}
template<typename T>
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -565,6 +621,25 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
}
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float start, stop, step;
memcpy(&start, dst->op_params, sizeof(float));
memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
float * dst_ptr = (float *)dst->data;
const int k = (int)ggml_nelements(dst);
const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
arange_kernel(dst_ptr, k, start, step, item_ct1);
});
}
} // namespace ggml_sycl_detail
@@ -870,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
}, min_val, max_val);
}
static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -1090,3 +1217,28 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_quick(ctx, dst);
}
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
}
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_floor(ctx, dst);
}
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_ceil(ctx, dst);
}
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_round(ctx, dst);
}
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_trunc(ctx, dst);
}
+6
View File
@@ -80,5 +80,11 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_ELEMENTWISE_HPP
+65
View File
@@ -42,6 +42,7 @@
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/quantize.hpp"
@@ -2151,6 +2152,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
}
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
const int64_t ncols = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
main_stream->parallel_for(
sycl::range<1>(nrows),
[=](sycl::id<1> row) {
dst_dd[row] /= ncols;
}
);
}
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
@@ -3535,6 +3560,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
ggml_sycl_op_sum_rows(ctx, dst);
}
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_mean(ctx, dst);
}
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -3589,6 +3620,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GET_ROWS:
ggml_sycl_get_rows(ctx, dst);
break;
case GGML_OP_SET:
ggml_sycl_op_set(ctx, dst);
break;
case GGML_OP_SET_ROWS:
ggml_sycl_op_set_rows(ctx, dst);
break;
@@ -3664,6 +3698,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_UNARY_OP_ELU:
ggml_sycl_elu(ctx, dst);
break;
case GGML_UNARY_OP_FLOOR:
ggml_sycl_floor(ctx, dst);
break;
case GGML_UNARY_OP_CEIL:
ggml_sycl_ceil(ctx, dst);
break;
case GGML_UNARY_OP_ROUND:
ggml_sycl_round(ctx, dst);
break;
case GGML_UNARY_OP_TRUNC:
ggml_sycl_trunc(ctx, dst);
break;
default:
return false;
}
@@ -3784,6 +3830,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_sycl_sum_rows(ctx, dst);
break;
case GGML_OP_MEAN:
ggml_sycl_mean(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_sycl_argsort(ctx, dst);
break;
@@ -3799,6 +3848,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_ARANGE:
ggml_sycl_arange(ctx, dst);
break;
default:
return false;
}
@@ -4222,6 +4274,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
#if defined (GGML_SYCL_F16)
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
#else
@@ -4295,6 +4351,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
}
case GGML_OP_SET:
return (op->type == GGML_TYPE_F32) &&
(op->src[0] && op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F32) &&
(op->src[1]->type == GGML_TYPE_F32);
case GGML_OP_SET_ROWS:
{
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
@@ -4431,6 +4493,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_POOL_2D:
@@ -4444,6 +4507,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
default:
return false;
}
+2
View File
@@ -31,6 +31,7 @@
#define SYCL_SQRT_BLOCK_SIZE 256
#define SYCL_SIN_BLOCK_SIZE 256
#define SYCL_SQR_BLOCK_SIZE 256
#define SYCL_SET_BLOCK_SIZE 256
#define SYCL_CPY_BLOCK_SIZE 32
#define SYCL_SCALE_BLOCK_SIZE 256
#define SYCL_CLAMP_BLOCK_SIZE 256
@@ -49,6 +50,7 @@
#define SYCL_ARGMAX_BLOCK_SIZE 256
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
#define SYCL_ARANGE_BLOCK_SIZE 256
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_SYCL_DMMV_X
+73
View File
@@ -0,0 +1,73 @@
#include "presets.hpp"
#include "common.hpp"
#include "ggml.h"
#include "set.hpp"
#include <cstdint>
#include <sycl/sycl.hpp>
using namespace sycl;
// Internal function: perform element-wise set operation for each thread
inline void set_f32(const float* src, float* dst,
const int64_t ne0, const int64_t ne1,
const int64_t ne2, const int64_t ne3,
const int64_t nb[3], const int64_t src_nb[3],
const int64_t offset_elem,
const nd_item<1>& item)
{
const size_t idx = item.get_global_id(0);
const size_t total = ne0 * ne1 * ne2 * ne3;
if (idx >= total) return;
// Convert linear index to 4D indices
const size_t i3 = idx / (ne2 * ne1 * ne0);
const size_t rem = idx % (ne2 * ne1 * ne0);
const size_t i2 = rem / (ne1 * ne0);
const size_t rem2 = rem % (ne1 * ne0);
const size_t i1 = rem2 / ne0;
const size_t i0 = rem2 % ne0;
// Compute source and destination indices and copy
dst[i0 + i1*nb[0] + i2*nb[1] + i3*nb[2] + offset_elem] =
src[i0 + i1*src_nb[0] + i2*src_nb[1] + i3*src_nb[2]];
}
// Main function: prepare GPU queue and launch parallel_for
void ggml_sycl_op_set(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor* src0 = dst->src[0];
const ggml_tensor* src1 = dst->src[1];
// Ensure shapes and types are compatible
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
GGML_ASSERT(dst->type == src0->type && src0->type == src1->type && dst->type == GGML_TYPE_F32);
const int32_t* opts = (const int32_t*) dst->op_params;
const int64_t nb[3] = {opts[0]/sizeof(float), opts[1]/sizeof(float), opts[2]/sizeof(float)};
const int64_t offset_elem = opts[3] / sizeof(float);
const bool inplace = opts[4];
float* dst_ptr = (float*) dst->data;
const float* src0_ptr = (const float*) src0->data;
const float* src1_ptr = (const float*) src1->data;
queue_ptr stream = ctx.stream();
// Copy src0 to dst if not inplace
if (!inplace)
stream->memcpy(dst_ptr, src0_ptr, ggml_nbytes(dst));
const int64_t ne[4] = {src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]};
const int64_t src_nb[3] = {src1->nb[1]/sizeof(float), src1->nb[2]/sizeof(float), src1->nb[3]/sizeof(float)};
const size_t total_threads = ne[0]*ne[1]*ne[2]*ne[3];
const size_t grid_size = ((total_threads + SYCL_SET_BLOCK_SIZE - 1) / SYCL_SET_BLOCK_SIZE) * SYCL_SET_BLOCK_SIZE;
// Copy src0 to dst if not inplace
stream->parallel_for(
nd_range<1>(range<1>(grid_size), range<1>(SYCL_SET_BLOCK_SIZE)),
[=](nd_item<1> item) {
set_f32(src1_ptr, dst_ptr,
ne[0], ne[1], ne[2], ne[3],
nb, src_nb, offset_elem, item); }
);
}
+5
View File
@@ -0,0 +1,5 @@
#pragma once
#include "backend.hpp"
#include "ggml.h"
void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+9
View File
@@ -1,9 +1,18 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW)
if (POLICY CMP0147)
# Parallel build custom build steps
cmake_policy(SET CMP0147 NEW)
endif()
find_package(Vulkan COMPONENTS glslc REQUIRED)
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
# Parallel build object files
add_definitions(/MP)
endif()
function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
+495 -14
View File
@@ -385,6 +385,14 @@ enum shader_reduction_mode {
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -582,6 +590,9 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@@ -595,6 +606,9 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
// [2] is {!norm, norm}
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
std::vector<vk_pipeline_ref> all_pipelines;
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
@@ -938,6 +952,11 @@ struct vk_op_multi_add_push_constants {
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
};
struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
@@ -1087,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
uint32_t nb42, nb43, nb52, nb53;
uint32_t s_off;
uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
uint32_t nb01, nb02;
uint32_t nb11;
uint32_t dst_nb0, dst_nb1, dst_nb2;
uint32_t nc, ncs, nr, n_t, n_s;
};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
@@ -2649,11 +2681,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \
}
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
@@ -2661,6 +2695,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
@@ -3588,6 +3623,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -3698,6 +3738,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
}
for (auto &c : compiles) {
c.wait();
}
@@ -7457,8 +7502,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
if (k->type == GGML_TYPE_F32) {
k_stride /= 4;
}
if (v->type == GGML_TYPE_F32) {
v_stride /= 4;
}
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
bool aligned = (KV % alignment) == 0 &&
@@ -7972,6 +8025,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
return ctx->device->pipeline_topk_moe[idx][with_norm];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
}
@@ -8087,6 +8147,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
if (d_state == 128) {
return ctx->device->pipeline_ssm_scan_f32_d128;
} else if (d_state == 256) {
return ctx->device->pipeline_ssm_scan_f32_d256;
}
}
return nullptr;
case GGML_OP_SSM_CONV:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ssm_conv_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
@@ -8581,6 +8656,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
break;
case GGML_OP_SSM_CONV:
{
const uint32_t nr = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
elements = { nr, n_t, n_s };
}
break;
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
@@ -9027,6 +9110,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = dst->src[3];
const ggml_tensor * src4 = dst->src[4];
const ggml_tensor * src5 = dst->src[5];
GGML_ASSERT(dst->buffer != nullptr);
const uint32_t head_dim = src0->ne[1];
const uint32_t n_head = src1->ne[1];
const uint32_t n_group = src4->ne[1];
const uint32_t n_tok = src1->ne[2];
const uint32_t n_seq = src1->ne[3];
bool is_mamba2 = (src3->nb[1] == sizeof(float));
GGML_ASSERT(is_mamba2);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
const vk_op_ssm_scan_push_constants pc = {
(uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
(uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
(uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
(uint32_t)src3->nb[1],
(uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
(uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
(uint32_t)s_off,
n_head, head_dim, n_group, n_tok
};
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
if (ctx->device->uma) {
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
}
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}
size_t dst_size = ggml_nbytes(dst);
size_t src_sizes[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
}
std::array<uint32_t, 3> elements;
const int splitH = 16;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
}
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
(uint32_t)src1->ne[0],
(uint32_t)src0->ne[0],
(uint32_t)src0->ne[1],
(uint32_t)dst->ne[1],
(uint32_t)dst->ne[2],
}, dryrun);
}
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * x = dst->src[0];
const ggml_tensor * g = dst->src[1];
@@ -9423,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];
const int n_expert_used = weights->ne[1];
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
vk_buffer d_logits = nullptr;
size_t logits_buf_offset = 0;
vk_buffer d_weights = nullptr;
size_t weights_buf_offset = 0;
vk_buffer d_ids = nullptr;
size_t ids_buf_offset = 0;
bool logits_uma = false;
bool weights_uma = false;
bool ids_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
logits_uma = d_logits != nullptr;
weights_uma = d_weights != nullptr;
ids_uma = d_ids != nullptr;
}
if (!logits_uma) {
d_logits = logits_buf_ctx->dev_buffer;
logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
GGML_ASSERT(d_logits != nullptr);
}
if (!weights_uma) {
d_weights = weights_buf_ctx->dev_buffer;
weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
GGML_ASSERT(d_weights != nullptr);
}
if (!ids_uma) {
d_ids = ids_buf_ctx->dev_buffer;
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
GGML_ASSERT(d_ids != nullptr);
}
vk_op_topk_moe_push_constants pc;
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
GGML_ASSERT(n_expert_used <= n_experts);
const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
}, pc, elements);
}
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@@ -10859,6 +11134,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
@@ -11006,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
}
// Add the last fused node and all fused source nodes to the unsynchronized list.
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
ctx->unsynced_nodes_written.push_back(last_node);
// Add all fused nodes to the unsynchronized lists.
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
// Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
ctx->unsynced_nodes_written.push_back(cur_node);
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
@@ -11177,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SOFT_MAX:
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
if (ctx->num_additional_fused_ops) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
} else {
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
}
break;
case GGML_OP_SOFT_MAX_BACK:
@@ -11276,6 +11557,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_SSM_CONV:
ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
@@ -11387,6 +11678,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
@@ -11961,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
return true;
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
int node_idx, bool with_norm) {
if (with_norm) {
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
return false;
}
}
} else {
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
return false;
}
}
}
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
float max_bias = op_params[1];
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
return false;
}
// Check that the nodes don't have any unexpected uses
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
// softmax is used by reshape and argsort
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
reshape1->src[0] != softmax ||
argsort->src[0] != softmax) {
return false;
}
// reshape is used by get_rows
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
get_rows->src[0] != reshape1) {
return false;
}
// argsort is used by view
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
view->src[0] != argsort) {
return false;
}
// view is written (via argsort), we can skip checking it
if (with_norm) {
// get_rows is used by reshape
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
reshape5->src[0] != get_rows) {
return false;
}
// reshape is used by sum_rows and div
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
sum_rows->src[0] != reshape5 ||
div->src[0] != reshape5) {
return false;
}
// sum_rows is used by div
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
div->src[1] != sum_rows) {
return false;
}
// div/reshape are written
if (reshape8->src[0] != div) {
return false;
}
}
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
ctx->device->disable_fusion) {
return false;
}
return true;
}
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
@@ -12036,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12133,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
}
}
@@ -12140,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i + ctx->num_additional_fused_ops == last_node) ||
(i + ctx->num_additional_fused_ops >= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
if (vk_perf_logger_enabled) {
if (ctx->compute_ctx.expired()) {
@@ -12264,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
// Avoid reordering topk_moe_norm
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
bool is_topk_moe_norm = true;
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
is_topk_moe_norm = false;
}
}
if (is_topk_moe_norm) {
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
continue;
}
}
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12660,6 +13094,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
switch (op->src[1]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
// supported in scalar and coopmat2 paths
@@ -12867,6 +13302,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
return false;
}
}
if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
return false;
}
const uint32_t d_state = op->src[0]->ne[0];
const uint32_t head_dim = op->src[0]->ne[1];
bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
if (!is_mamba2) {
return false;
}
if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
const uint32_t SPLIT_H = 16;
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
return true;
}
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
@@ -13211,14 +13687,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
struct ggml_context * ggml_ctx = ggml_init(iparams);
std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, GGML_MAX_SRC> src_size = {};
std::array<void *, GGML_MAX_SRC> src_buffer = {};
const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
struct ggml_tensor * tensor_clone = nullptr;
for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
@@ -13525,6 +14001,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
} else if (tensor->op == GGML_OP_SSM_SCAN) {
tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_SSM_CONV) {
tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -13546,7 +14027,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (src_buffer[i] != nullptr) {
free(src_buffer[i]);
}
@@ -1,6 +1,18 @@
#include "types.glsl"
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
vec4 block;
};
float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const vec4 v = bl.block;
const uint idx = coordInBlock[1];
const f16vec4 vf16 = f16vec4(v);
return vf16[idx];
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};
@@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif
@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return k_packed.k_data_packed[a_offset + ib];
} else {
return v_packed.v_data_packed[a_offset + ib];
}
}
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
+30 -20
View File
@@ -313,12 +313,12 @@ void main() {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
ACC_TYPE sums[WMITER * TM * WNITER * TN];
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
FLOAT_TYPE_VEC2 cache_b[TN];
FLOAT_TYPE_VEC2 cache_b;
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
}
#endif
@@ -360,20 +360,22 @@ void main() {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
}
}
}
}
}
#endif
@@ -388,8 +390,9 @@ void main() {
}
}
#else
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
}
#endif
#endif
@@ -463,14 +466,21 @@ void main() {
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#ifdef MUL_MAT_ID
if (dr_warp + cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + 2 * cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
}
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
}
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
}
#endif // MUL_MAT_ID
}
@@ -0,0 +1,44 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };
layout(push_constant) uniform PushConstants {
uint nb01; uint nb02;
uint nb11;
uint dst_nb0; uint dst_nb1; uint dst_nb2;
uint nc; uint ncs; uint nr; uint n_t; uint n_s;
};
void main() {
const uint global_thread_id = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i3 = gl_WorkGroupID.z;
if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
return;
}
const uint i1 = global_thread_id;
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
const uint src1_base = i1 * (nb11 / 4);
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
float sum = 0.0;
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
const uint src0_idx = src0_base + i0;
const uint src1_idx = src1_base + i0;
sum += src0[src0_idx] * src1[src1_idx];
}
dst[dst_idx] = sum;
}
@@ -0,0 +1,125 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float s0[]; };
layout(binding = 1) readonly buffer Src1 { float x[]; };
layout(binding = 2) readonly buffer Src2 { float dt[]; };
layout(binding = 3) readonly buffer Src3 { float A[]; };
layout(binding = 4) readonly buffer Src4 { float B[]; };
layout(binding = 5) readonly buffer Src5 { float C[]; };
layout(binding = 6) readonly buffer Src6 { int ids[]; };
layout(binding = 7) buffer Dst { float d[]; };
layout(push_constant) uniform PushConstants {
uint nb02; uint nb03; uint nb12; uint nb13;
uint nb21; uint nb22; uint nb31;
uint nb42; uint nb43; uint nb52; uint nb53;
uint s_off;
uint n_head;
uint d_head;
uint n_group;
uint n_tok;
};
float softplus(float x) {
if (x <= 20.0) {
return log(1.0 + exp(x));
} else {
return x;
}
}
shared float stateC[SPLIT_H * D_STATE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
const uint seq_idx = gl_WorkGroupID.y;
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
const uint A_base_idx = (head_idx * nb31) / 4;
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint stride_x = nb12 / 4;
const uint stride_dt = nb21 / 4;
const uint stride_B = nb42 / 4;
const uint stride_C = nb52 / 4;
const uint stride_y = n_head * d_head;
float state[SPLIT_H];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
state[j] = s0[s0_base_idx + j * D_STATE + tid];
}
for (uint i = 0; i < n_tok; i++) {
const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
const float dA = exp(dt_soft_plus * A[A_base_idx]);
const float B_val = B[B_base_idx + i * stride_B + tid];
const float C_val = C[C_base_idx + i * stride_C + tid];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
state[j] = (state[j] * dA) + (B_val * x_dt);
stateC[j * D_STATE + tid] = state[j] * C_val;
}
barrier();
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % (w >> 1)) +
(D_STATE * (tid / (w >> 1))) +
j * D_STATE * (D_STATE / (w >> 1));
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + (w >> 1)];
}
}
barrier();
}
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
uint lane = tid % SUBGROUP_SIZE;
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = stateC[idx];
}
}
barrier();
}
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
d[s_base_idx + j * D_STATE + tid] = state[j];
}
}
@@ -0,0 +1,139 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
return;
}
const uint logits_offset = n_experts * row;
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
float logits_r[experts_per_thread];
const float INFINITY = 1.0 / 0.0;
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
}
float max_val = logits_r[0];
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
}
max_val = subgroupMax(max_val);
float wt[experts_per_thread];
float tmp = 0.f;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = exp(val - max_val);
tmp += wt[i];
}
tmp = subgroupAdd(tmp);
const float inv_sum = 1.0f / tmp;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
}
// at this point, each thread holds a portion of softmax,
// we do the argmax reduce over n_expert_used, each time marking
// the expert weight as -inf to exclude from the next iteration
float wt_sum = 0.f;
float output_weights[experts_per_thread];
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
[[unroll]]
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = subgroupShuffleXor(max_val, mask);
const uint expert = subgroupShuffleXor(max_expert, mask);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
output_weights[k / WARP_SIZE] = max_val;
}
if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[ids_offset + k] = max_expert;
if (with_norm) {
wt_sum += max_val;
}
}
}
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
output_weights[i] *= inv_sum;
}
}
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
if (idx < n_expert_used) {
weights[weights_offset + idx] = output_weights[i];
}
}
}
@@ -611,9 +611,6 @@ void process_shaders() {
}
for (const auto& tname : type_names) {
if (tname == "f32") {
continue;
}
if (tname == "bf16") continue;
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -630,7 +627,7 @@ void process_shaders() {
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
@@ -639,7 +636,7 @@ void process_shaders() {
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
@@ -919,6 +916,12 @@ void process_shaders() {
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_moe_f32", "topk_moe.comp", {});
for (auto &c : compiles) {
c.wait();
}
@@ -962,7 +965,7 @@ void write_output_files() {
}
std::string suffixes[2] = {"_f32", "_f16"};
for (auto op : {"add", "sub", "mul", "div", "add_rms"}) {
for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
+61 -1
View File
@@ -1144,9 +1144,13 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"EXP",
"GELU_ERF",
"XIELU",
"FLOOR",
"CEIL",
"ROUND",
"TRUNC",
};
static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20");
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
@@ -2749,6 +2753,62 @@ static struct ggml_tensor * ggml_glu_impl(
return result;
}
// ggml_floor
struct ggml_tensor * ggml_floor(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);
}
struct ggml_tensor * ggml_floor_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR);
}
// ggml_ceil
struct ggml_tensor * ggml_ceil(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL);
}
struct ggml_tensor * ggml_ceil_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL);
}
//ggml_round
struct ggml_tensor * ggml_round(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND);
}
struct ggml_tensor * ggml_round_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND);
}
//ggml_trunc
struct ggml_tensor * ggml_trunc(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC);
}
struct ggml_tensor * ggml_trunc_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC);
}
struct ggml_tensor * ggml_glu(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -91,6 +91,7 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
tensor.tensor_type not in (
gguf.GGMLQuantizationType.F32,
gguf.GGMLQuantizationType.F16,
gguf.GGMLQuantizationType.BF16,
):
raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
logger.info(f"* Preparing to convert from {file_endian} to {order}")
@@ -148,6 +149,11 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
# restore old shape in case it's ever used
tensor.data.resize(oldshape)
elif tensor.tensor_type == gguf.GGMLQuantizationType.BF16:
# Special case for BF16
# It is 2-bytes data, but by default view loads it as 1-byte data.
# Change to correct view before byteswapping.
tensor.data.view(dtype=np.uint16).byteswap(inplace=True)
else:
# Handle other tensor types
tensor.data.byteswap(inplace=True)
+5
View File
@@ -5,6 +5,7 @@
#include <map>
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" },
@@ -275,6 +276,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
};
static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{
LLM_ARCH_CLIP,
{},
},
{
LLM_ARCH_LLAMA,
{
+1
View File
@@ -9,6 +9,7 @@
//
enum llm_arch {
LLM_ARCH_CLIP,
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_DECI,
+1 -1
View File
@@ -123,7 +123,7 @@ private:
uint32_t n_seq_max;
uint32_t n_outputs;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
+2 -1
View File
@@ -2346,7 +2346,8 @@ llama_context * llama_init_from_model(
return nullptr;
}
if (params.pooling_type != model->hparams.pooling_type) {
if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
params.pooling_type != model->hparams.pooling_type) {
//user-specified pooling-type is different from the model default
LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
model->hparams.pooling_type, params.pooling_type);
+74 -43
View File
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}
static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
const char * swa_type_str = "unknown";
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
};
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
for (int h = 0; h < 1; ++h) {
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
float * data = (float *) kq_mask->data;
const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
// [TAG_NO_CACHE_ISWA]
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
for (int h = 0; h < 1; ++h) {
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
for (int i0 = 0; i0 < n_tokens; ++i0) {
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
for (int i0 = 0; i0 < n_tokens; ++i0) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
const llama_pos p0 = ubatch->pos[i0];
// mask different sequences
if (s0 != s1) {
continue; // skip different sequences
continue;
}
if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
continue; // skip future tokens for causal attention
// mask future tokens
if (cparams.causal_attn && p0 > p1) {
continue;
}
// TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
//if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
// continue; // skip masked tokens for SWA
//}
// TODO: reimplement this like in llama_kv_cache_unified
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
} else {
f = 0.0f;
// apply SWA if any
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
continue;
}
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
}
}
};
{
GGML_ASSERT(self_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
float * data = (float *) self_kq_mask->data;
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
if (debug) {
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
}
}
if (debug) {
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(self_kq_mask_swa);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
float * data = (float *) self_kq_mask_swa->data;
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
fill_mask(data, hparams.n_swa, hparams.swa_type);
if (debug) {
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
}
}
}
@@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
const auto n_kv = k->ne[1];
ggml_tensor * cur;
// TODO: replace hardcoded padding with ggml-provided padding
if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
if (cparams.flash_attn && kq_b == nullptr) {
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
if (v_trans) {
@@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->kq_mask);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask);
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
} else {
inp->self_kq_mask_swa = nullptr;
inp->self_kq_mask_swa_cnv = nullptr;
}
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}
@@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto & kq_mask = inp->get_kq_mask();
const bool is_swa = hparams.is_swa(il);
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+7 -3
View File
@@ -257,10 +257,14 @@ public:
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
// n_tokens == n_batch
ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
const llama_hparams hparams;
const llama_cparams cparams;
+41 -41
View File
@@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_A13B: return "A13B";
case LLM_TYPE_7B_A1B: return "7B.A1B";
case LLM_TYPE_8B_A1B: return "8B.A1B";
case LLM_TYPE_21B_A3B: return "21B.A3B";
case LLM_TYPE_30B_A3B: return "30B.A3B";
@@ -421,11 +422,8 @@ struct llama_model::impl {
llama_mlocks mlock_bufs;
llama_mlocks mlock_mmaps;
// contexts where the model tensors metadata is stored
std::vector<ggml_context_ptr> ctxs;
// the model memory buffers for the tensor data
std::vector<ggml_backend_buffer_ptr> bufs;
// contexts where the model tensors metadata is stored as well ass the corresponding buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
buft_list_t cpu_buft_list;
std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
@@ -478,7 +476,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_GENERAL_NAME, name, false);
// everything past this point is not vocab-related
if (hparams.vocab_only) {
// for CLIP models, we only need to load tensors, no hparams
if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
return;
}
@@ -1845,8 +1844,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
// TODO: Add llm type label (not sure this is useful)
switch (hparams.n_embd) {
case 1536: type = LLM_TYPE_7B_A1B; break;
case 2048: case 2560: type = LLM_TYPE_3B; break;
case 4096: type = LLM_TYPE_32B; break;
default: type = LLM_TYPE_UNKNOWN;
}
@@ -2181,7 +2182,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
max_n_tensors += n_layer*2; // duplicated rope freq tensors
const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
@@ -2196,12 +2204,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
throw std::runtime_error(format("failed to create ggml context"));
}
ctx_map[buft] = ctx;
pimpl->ctxs.emplace_back(ctx);
ctx_map.emplace(buft, ctx);
return ctx;
}
return it->second;
return it->second.get();
};
const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED;
@@ -6036,16 +6043,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
pimpl->mappings.reserve(ml.mappings.size());
// create the backend buffers
std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
ctx_bufs.reserve(ctx_map.size());
std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_buf_maps;
ctx_buf_maps.reserve(ctx_map.size());
// Ensure we have enough capacity for the maximum backend buffer we will potentially create
const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
pimpl->bufs.reserve(n_max_backend_buffer);
pimpl->ctxs_bufs.reserve(n_max_backend_buffer);
for (auto & it : ctx_map) {
ggml_backend_buffer_type_t buft = it.first;
ggml_context * ctx = it.second;
for (auto & [buft, ctx_ptr] : ctx_map) {
ggml_context * ctx = ctx_ptr.get();
// skip contexts without tensors
if (ggml_get_first_tensor(ctx) == nullptr) {
@@ -6069,6 +6075,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
ggml_backend_buffer_t buf = nullptr;
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
// only the mmap region containing the tensors in the model is mapped to the backend buffer
@@ -6081,20 +6088,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
continue;
}
const size_t max_size = ggml_get_max_tensor_size(ctx);
ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
pimpl->bufs.emplace_back(buf);
buf_map.emplace(idx, buf);
}
}
else {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
pimpl->bufs.emplace_back(buf);
if (use_mlock && ggml_backend_buffer_is_host(buf)) {
pimpl->mlock_bufs.emplace_back(new llama_mlock);
auto & mlock_buf = pimpl->mlock_bufs.back();
@@ -6105,10 +6110,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
buf_map.emplace(idx, buf);
}
}
if (pimpl->bufs.empty()) {
throw std::runtime_error("failed to allocate buffer");
}
pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
for (auto & buf : buf_map) {
// indicate that this buffer contains weights
@@ -6116,7 +6118,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
}
ctx_bufs.emplace_back(ctx, buf_map);
ctx_buf_maps.emplace_back(ctx, buf_map);
}
if (llama_supports_gpu_offload()) {
@@ -6134,22 +6136,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
// print memory requirements per buffer type
for (auto & buf : pimpl->bufs) {
for (auto & [_, buf] : pimpl->ctxs_bufs) {
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
}
// populate tensors_by_name
for (auto & ctx : pimpl->ctxs) {
for (auto & [ctx, _] : pimpl->ctxs_bufs) {
for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
tensors_by_name.emplace_back(ggml_get_name(cur), cur);
}
}
// load tensor data
for (auto & it : ctx_bufs) {
ggml_context * ctx = it.first;
auto & bufs = it.second;
if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
for (auto & [ctx, buf_map] : ctx_buf_maps) {
if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
return false;
}
}
@@ -6189,8 +6189,8 @@ size_t llama_model::n_devices() const {
std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : pimpl->ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
@@ -11358,8 +11358,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
}
};
struct llm_build_gemma_embedding_iswa : public llm_graph_context {
llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
struct llm_build_gemma_embedding : public llm_graph_context {
llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
ggml_tensor * cur;
@@ -11376,8 +11376,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
// TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
auto * inp_attn = build_attn_inp_kv_iswa();
auto * inp_attn = build_attn_inp_no_cache();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -19378,7 +19377,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
//case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE:
@@ -19671,7 +19670,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_GEMMA_EMBEDDING:
{
llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
} break;
case LLM_ARCH_STARCODER2:
{
@@ -20014,6 +20013,7 @@ int32_t llama_n_head(const llama_model * model) {
llama_rope_type llama_model_rope_type(const llama_model * model) {
switch (model->arch) {
// these models do not use RoPE
case LLM_ARCH_CLIP:
case LLM_ARCH_GPT2:
case LLM_ARCH_GPTJ:
case LLM_ARCH_MPT:
+1
View File
@@ -107,6 +107,7 @@ enum llm_type {
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_A13B,
LLM_TYPE_7B_A1B,
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
+7 -1
View File
@@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
});
}
bool is_clip_model = false;
for (const auto * it : tensors) {
const struct ggml_tensor * tensor = it->tensor;
@@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
qs.has_output = true;
}
is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
}
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
// sanity checks for models that have attention layers
if (qs.n_attention_wv != 0)
if (qs.n_attention_wv != 0 && !is_clip_model)
{
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads
@@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
// do not quantize specific multimodal tensors
quantize &= name.find(".position_embd.") == std::string::npos;
ggml_type new_type;
void * new_data;
size_t new_size;
+4
View File
@@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
} catch(const std::exception & e) {
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
}
if (model.arch == LLM_ARCH_CLIP) {
throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
}
try {
model.load_vocab(ml);
} catch(const std::exception & e) {
@@ -312,6 +315,7 @@ struct llama_model * llama_model_load_from_splits(
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
return nullptr;
}
splits.reserve(n_paths);
for (size_t i = 0; i < n_paths; ++i) {
splits.push_back(paths[i]);
}
+176 -5
View File
@@ -3759,6 +3759,130 @@ struct test_clamp : public test_case {
}
};
// GGML_OP_FLOOR
struct test_floor : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_floor(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_floor(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_CEIL
struct test_ceil : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_ceil(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_ceil(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_ROUND
struct test_round : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_round(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_round(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_TRUNC
struct test_trunc : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_trunc(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_trunc(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_DIAG_MASK_INF
struct test_diag_mask_inf : public test_case {
const ggml_type type;
@@ -4588,20 +4712,31 @@ struct test_topk_moe: public test_case {
struct test_sum : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int64_t, 4> permute;
bool _use_permute;
std::string vars() override {
return VARS_TO_STR2(type, ne);
std::string v = VARS_TO_STR2(type, ne);
if (_use_permute) v += "," + VAR_TO_STR(permute);
return v;
}
test_sum(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3})
: type(type), ne(ne) {}
std::array<int64_t, 4> ne = {10, 5, 4, 3},
std::array<int64_t, 4> permute = {0, 0, 0, 0})
: type(type), ne(ne), permute(permute),
_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
if (_use_permute) {
a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
ggml_set_name(a, "a_permuted");
}
ggml_tensor * out = ggml_sum(ctx, a);
ggml_set_name(out, "out");
@@ -6354,6 +6489,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
#if 0
{
// Test paths in OpenCL
std::vector<int> ns = {32, 64, 128, 256, 512, 1024, 4096};
std::vector<int> ks = {896, 1536, 4096};
for (auto n : ns) {
for (auto k : ks) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1}));
}
}
}
#endif
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@@ -6561,6 +6709,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cos (type));
test_cases.emplace_back(new test_clamp (type));
test_cases.emplace_back(new test_leaky_relu(type));
test_cases.emplace_back(new test_floor (type));
test_cases.emplace_back(new test_ceil (type));
test_cases.emplace_back(new test_round (type));
test_cases.emplace_back(new test_trunc (type));
test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));
@@ -6568,6 +6720,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
}
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
@@ -6724,6 +6880,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
@@ -6734,6 +6893,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
@@ -6779,7 +6939,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
// run fewer test cases permuted
@@ -6911,7 +7071,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
// qwen3-30b-a3b
for (int bs : {1, 4, 8, 32, 64, 128, 512}) {
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
@@ -6919,6 +7079,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
}
}
}
// gpt-oss-20b
for (int bs : {1, 4, 8, 512}) {
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
@@ -6952,6 +7121,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+24
View File
@@ -301,6 +301,30 @@ static void test_simple_grammar() {
"0123",
}
);
test_schema(
"min 1 max 900719925474091",
// Schema
R"""({
"type": "integer",
"exclusiveMinimum": 0,
"maximum": 900719925474091
})""",
// Passing strings
{
"1",
"2",
"10",
"900719925474090",
"900719925474091",
},
// Failing strings
{
"0",
"01",
"900719925474092",
"9007199254740910",
}
);
test_schema(
"min -1 max 1",
R"""({
+2
View File
@@ -30,6 +30,7 @@
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
@@ -48,6 +49,7 @@
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
// audio-specific
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
+15 -3
View File
@@ -2221,15 +2221,27 @@ struct clip_model_loader {
// projector type
std::string proj_type;
{
// default key
get_string(KEY_PROJ_TYPE, proj_type, false);
if (!proj_type.empty()) {
model.proj_type = clip_projector_type_from_string(proj_type);
// for models with mixed modalities
if (proj_type.empty()) {
if (modality == CLIP_MODALITY_VISION) {
get_string(KEY_VISION_PROJ_TYPE, proj_type, false);
} else if (modality == CLIP_MODALITY_AUDIO) {
get_string(KEY_AUDIO_PROJ_TYPE, proj_type, false);
} else {
GGML_ABORT("unknown modality");
}
}
model.proj_type = clip_projector_type_from_string(proj_type);
if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
}
// correct arch for multimodal models
// correct arch for multimodal models (legacy method)
if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
model.proj_type = modality == CLIP_MODALITY_VISION
? PROJECTOR_TYPE_QWEN25VL
+1 -33
View File
@@ -137,7 +137,6 @@ struct rpc_server_params {
bool use_cache = false;
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
std::vector<std::string> devices;
std::vector<size_t> dev_mem;
};
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
@@ -148,7 +147,6 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
}
@@ -197,23 +195,6 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
}
} else if (arg == "-c" || arg == "--cache") {
params.use_cache = true;
} else if (arg == "-m" || arg == "--mem") {
if (++i >= argc) {
return false;
}
const std::regex regex{ R"([,/]+)" };
std::string mem_str = argv[i];
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
std::sregex_token_iterator end;
for ( ; iter != end; ++iter) {
try {
size_t mem = std::stoul(*iter) * 1024 * 1024;
params.dev_mem.push_back(mem);
} catch (const std::exception & ) {
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
return false;
}
}
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params);
exit(0);
@@ -293,18 +274,6 @@ int main(int argc, char * argv[]) {
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
std::vector<size_t> free_mem, total_mem;
for (size_t i = 0; i < devices.size(); i++) {
if (i < params.dev_mem.size()) {
free_mem.push_back(params.dev_mem[i]);
total_mem.push_back(params.dev_mem[i]);
} else {
size_t free, total;
ggml_backend_dev_memory(devices[i], &free, &total);
free_mem.push_back(free);
total_mem.push_back(total);
}
}
const char * cache_dir = nullptr;
std::string cache_dir_str;
if (params.use_cache) {
@@ -328,7 +297,6 @@ int main(int argc, char * argv[]) {
return 1;
}
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
devices.data(), free_mem.data(), total_mem.data());
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
return 0;
}
Binary file not shown.
+18 -10
View File
@@ -1585,23 +1585,31 @@ struct server_prompt_cache {
}
}
// average size per token
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
// dynamically increase the token limit if it can fit in the memory limit
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
if (limit_tokens > 0) {
while (states.size() > 1 && n_tokens() > limit_tokens) {
while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
if (states.empty()) {
break;
}
SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
states.pop_front();
}
}
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n",
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens);
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
for (const auto & state : states) {
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
(const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
}
}
};
@@ -3804,7 +3812,7 @@ struct server_context {
if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
@@ -3831,14 +3839,14 @@ struct server_context {
{
const auto token = slot.prompt.tokens[i];
const auto piece = common_token_to_piece(ctx, token);
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
ss0 << piece;
st0 << std::setw(8) << token;
}
{
const auto token = slot.task->tokens[i];
const auto piece = common_token_to_piece(ctx, token);
const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]";
ss1 << piece;
st1 << std::setw(8) << token;
}
@@ -3852,7 +3860,7 @@ struct server_context {
}
if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
const auto it = std::find_if(
@@ -4020,7 +4028,7 @@ struct server_context {
}
}
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
+3 -2
View File
@@ -1237,9 +1237,10 @@ public:
// allowed to resize ^ ^
// disallowed to resize ^ ^ ^
if (n > 0) {
llama_token last_token = tokens[n - 1];
// make sure we never remove tokens in the middle of an image
if (last_token == LLAMA_TOKEN_NULL) {
// note that the case where we keep a full image at the end is allowed:
// tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL
if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) {
find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
}
}
+527
View File
@@ -50,6 +50,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
@@ -2979,6 +2980,13 @@
"node": ">=4"
}
},
"node_modules/async": {
"version": "3.2.6",
"resolved": "https://registry.npmjs.org/async/-/async-3.2.6.tgz",
"integrity": "sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==",
"dev": true,
"license": "MIT"
},
"node_modules/axe-core": {
"version": "4.10.3",
"resolved": "https://registry.npmjs.org/axe-core/-/axe-core-4.10.3.tgz",
@@ -3015,6 +3023,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/basic-auth": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/basic-auth/-/basic-auth-2.0.1.tgz",
"integrity": "sha512-NF+epuEdnUYVlGuhaxbbq+dvJttwLnGY+YixlXlME5KpQ5W3CnXA5cVTneY3SPbPDRkcjMbifrwmFYcClgOZeg==",
"dev": true,
"license": "MIT",
"dependencies": {
"safe-buffer": "5.1.2"
},
"engines": {
"node": ">= 0.8"
}
},
"node_modules/better-opn": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/better-opn/-/better-opn-3.0.2.tgz",
@@ -3125,6 +3146,37 @@
"node": ">=8"
}
},
"node_modules/call-bind-apply-helpers": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz",
"integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"function-bind": "^1.1.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/call-bound": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz",
"integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.2",
"get-intrinsic": "^1.3.0"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/callsites": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
@@ -3335,6 +3387,16 @@
"node": ">= 0.6"
}
},
"node_modules/corser": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/corser/-/corser-2.0.1.tgz",
"integrity": "sha512-utCYNzRSQIZNPIcGZdQc92UVJYAhtGAteCFg0yRaFm8f0P+CPtyGyHXJcGXnffjCybUCEx3FQ2G7U3/o9eIkVQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4.0"
}
},
"node_modules/cross-spawn": {
"version": "7.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",
@@ -3520,6 +3582,21 @@
"dev": true,
"license": "MIT"
},
"node_modules/dunder-proto": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz",
"integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.1",
"es-errors": "^1.3.0",
"gopd": "^1.2.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/enhanced-resolve": {
"version": "5.18.2",
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.2.tgz",
@@ -3547,6 +3624,26 @@
"url": "https://github.com/fb55/entities?sponsor=1"
}
},
"node_modules/es-define-property": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz",
"integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-errors": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
"integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-module-lexer": {
"version": "1.7.0",
"resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz",
@@ -3554,6 +3651,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/es-object-atoms": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz",
"integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-toolkit": {
"version": "1.39.7",
"resolved": "https://registry.npmjs.org/es-toolkit/-/es-toolkit-1.39.7.tgz",
@@ -3885,6 +3995,13 @@
"node": ">=0.10.0"
}
},
"node_modules/eventemitter3": {
"version": "4.0.7",
"resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz",
"integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==",
"dev": true,
"license": "MIT"
},
"node_modules/expect-type": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.2.2.tgz",
@@ -4058,6 +4175,27 @@
"dev": true,
"license": "ISC"
},
"node_modules/follow-redirects": {
"version": "1.15.11",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz",
"integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==",
"dev": true,
"funding": [
{
"type": "individual",
"url": "https://github.com/sponsors/RubenVerborgh"
}
],
"license": "MIT",
"engines": {
"node": ">=4.0"
},
"peerDependenciesMeta": {
"debug": {
"optional": true
}
}
},
"node_modules/fsevents": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
@@ -4073,6 +4211,55 @@
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
}
},
"node_modules/function-bind": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
"integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-intrinsic": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
"integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.2",
"es-define-property": "^1.0.1",
"es-errors": "^1.3.0",
"es-object-atoms": "^1.1.1",
"function-bind": "^1.1.2",
"get-proto": "^1.0.1",
"gopd": "^1.2.0",
"has-symbols": "^1.1.0",
"hasown": "^2.0.2",
"math-intrinsics": "^1.1.0"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-proto": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
"integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==",
"dev": true,
"license": "MIT",
"dependencies": {
"dunder-proto": "^1.0.1",
"es-object-atoms": "^1.0.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/glob-parent": {
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",
@@ -4099,6 +4286,19 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/gopd": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz",
"integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/graceful-fs": {
"version": "4.2.11",
"resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz",
@@ -4123,6 +4323,32 @@
"node": ">=8"
}
},
"node_modules/has-symbols": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
"integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/hasown": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz",
"integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"function-bind": "^1.1.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/hast-util-from-dom": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/hast-util-from-dom/-/hast-util-from-dom-5.0.1.tgz",
@@ -4363,6 +4589,16 @@
"url": "https://opencollective.com/unified"
}
},
"node_modules/he": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/he/-/he-1.2.0.tgz",
"integrity": "sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw==",
"dev": true,
"license": "MIT",
"bin": {
"he": "bin/he"
}
},
"node_modules/highlight.js": {
"version": "11.11.1",
"resolved": "https://registry.npmjs.org/highlight.js/-/highlight.js-11.11.1.tgz",
@@ -4372,6 +4608,19 @@
"node": ">=12.0.0"
}
},
"node_modules/html-encoding-sniffer": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-3.0.0.tgz",
"integrity": "sha512-oWv4T4yJ52iKrufjnyZPkrN0CH3QnrUqdB6In1g5Fe1mia8GmF36gnfNySxoZtxD5+NmYw1EElVXiBk93UeskA==",
"dev": true,
"license": "MIT",
"dependencies": {
"whatwg-encoding": "^2.0.0"
},
"engines": {
"node": ">=12"
}
},
"node_modules/html-void-elements": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",
@@ -4382,6 +4631,62 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/http-proxy": {
"version": "1.18.1",
"resolved": "https://registry.npmjs.org/http-proxy/-/http-proxy-1.18.1.tgz",
"integrity": "sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"eventemitter3": "^4.0.0",
"follow-redirects": "^1.0.0",
"requires-port": "^1.0.0"
},
"engines": {
"node": ">=8.0.0"
}
},
"node_modules/http-server": {
"version": "14.1.1",
"resolved": "https://registry.npmjs.org/http-server/-/http-server-14.1.1.tgz",
"integrity": "sha512-+cbxadF40UXd9T01zUHgA+rlo2Bg1Srer4+B4NwIHdaGxAGGv59nYRnGGDJ9LBk7alpS0US+J+bLLdQOOkJq4A==",
"dev": true,
"license": "MIT",
"dependencies": {
"basic-auth": "^2.0.1",
"chalk": "^4.1.2",
"corser": "^2.0.1",
"he": "^1.2.0",
"html-encoding-sniffer": "^3.0.0",
"http-proxy": "^1.18.1",
"mime": "^1.6.0",
"minimist": "^1.2.6",
"opener": "^1.5.1",
"portfinder": "^1.0.28",
"secure-compare": "3.0.1",
"union": "~0.5.0",
"url-join": "^4.0.1"
},
"bin": {
"http-server": "bin/http-server"
},
"engines": {
"node": ">=12"
}
},
"node_modules/iconv-lite": {
"version": "0.6.3",
"resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz",
"integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==",
"dev": true,
"license": "MIT",
"dependencies": {
"safer-buffer": ">= 2.1.2 < 3.0.0"
},
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/ignore": {
"version": "5.3.2",
"resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz",
@@ -5008,6 +5313,16 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/math-intrinsics": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
"integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/mdast": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",
@@ -5976,6 +6291,19 @@
"url": "https://github.com/sponsors/jonschlinkert"
}
},
"node_modules/mime": {
"version": "1.6.0",
"resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz",
"integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==",
"dev": true,
"license": "MIT",
"bin": {
"mime": "cli.js"
},
"engines": {
"node": ">=4"
}
},
"node_modules/min-indent": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz",
@@ -6009,6 +6337,16 @@
"node": "*"
}
},
"node_modules/minimist": {
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz",
"integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/minipass": {
"version": "7.1.2",
"resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz",
@@ -6124,6 +6462,19 @@
"tslib": "^2.0.3"
}
},
"node_modules/object-inspect": {
"version": "1.13.4",
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
"integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/open": {
"version": "8.4.2",
"resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz",
@@ -6142,6 +6493,16 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/opener": {
"version": "1.5.2",
"resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz",
"integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==",
"dev": true,
"license": "(WTFPL OR MIT)",
"bin": {
"opener": "bin/opener-bin.js"
}
},
"node_modules/optionator": {
"version": "0.9.4",
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
@@ -6330,6 +6691,20 @@
"node": ">=18"
}
},
"node_modules/portfinder": {
"version": "1.0.38",
"resolved": "https://registry.npmjs.org/portfinder/-/portfinder-1.0.38.tgz",
"integrity": "sha512-rEwq/ZHlJIKw++XtLAO8PPuOQA/zaPJOZJ37BVuN97nLpMJeuDVLVGRwbFoBgLudgdTMP2hdRJP++H+8QOA3vg==",
"dev": true,
"license": "MIT",
"dependencies": {
"async": "^3.2.6",
"debug": "^4.3.6"
},
"engines": {
"node": ">= 10.12"
}
},
"node_modules/postcss": {
"version": "8.5.6",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz",
@@ -6680,6 +7055,22 @@
"node": ">=6"
}
},
"node_modules/qs": {
"version": "6.14.0",
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz",
"integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==",
"dev": true,
"license": "BSD-3-Clause",
"dependencies": {
"side-channel": "^1.1.0"
},
"engines": {
"node": ">=0.6"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/queue-microtask": {
"version": "1.2.3",
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
@@ -6959,6 +7350,13 @@
"url": "https://opencollective.com/unified"
}
},
"node_modules/requires-port": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz",
"integrity": "sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==",
"dev": true,
"license": "MIT"
},
"node_modules/resolve-from": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz",
@@ -7072,6 +7470,20 @@
"node": ">=6"
}
},
"node_modules/safe-buffer": {
"version": "5.1.2",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz",
"integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==",
"dev": true,
"license": "MIT"
},
"node_modules/safer-buffer": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
"integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==",
"dev": true,
"license": "MIT"
},
"node_modules/scheduler": {
"version": "0.26.0",
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.26.0.tgz",
@@ -7079,6 +7491,13 @@
"dev": true,
"license": "MIT"
},
"node_modules/secure-compare": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/secure-compare/-/secure-compare-3.0.1.tgz",
"integrity": "sha512-AckIIV90rPDcBcglUwXPF3kg0P0qmPsPXAj6BBEENQE1p5yA1xfmDJzfi1Tappj37Pv2mVbKpL3Z1T+Nn7k1Qw==",
"dev": true,
"license": "MIT"
},
"node_modules/semver": {
"version": "7.7.2",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz",
@@ -7122,6 +7541,82 @@
"node": ">=8"
}
},
"node_modules/side-channel": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz",
"integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"object-inspect": "^1.13.3",
"side-channel-list": "^1.0.0",
"side-channel-map": "^1.0.1",
"side-channel-weakmap": "^1.0.2"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-list": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz",
"integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"object-inspect": "^1.13.3"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-map": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz",
"integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bound": "^1.0.2",
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.5",
"object-inspect": "^1.13.3"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-weakmap": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz",
"integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bound": "^1.0.2",
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.5",
"object-inspect": "^1.13.3",
"side-channel-map": "^1.0.1"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/siginfo": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz",
@@ -7904,6 +8399,18 @@
"integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==",
"license": "MIT"
},
"node_modules/union": {
"version": "0.5.0",
"resolved": "https://registry.npmjs.org/union/-/union-0.5.0.tgz",
"integrity": "sha512-N6uOhuW6zO95P3Mel2I2zMsbsanvvtgn6jVqJv4vbVcz/JN0OkL9suomjQGmWtxJQXOCqUJvquc1sMeNz/IwlA==",
"dev": true,
"dependencies": {
"qs": "^6.4.0"
},
"engines": {
"node": ">= 0.8.0"
}
},
"node_modules/unist-util-find-after": {
"version": "5.0.0",
"resolved": "https://registry.npmjs.org/unist-util-find-after/-/unist-util-find-after-5.0.0.tgz",
@@ -8073,6 +8580,13 @@
"punycode": "^2.1.0"
}
},
"node_modules/url-join": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/url-join/-/url-join-4.0.1.tgz",
"integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==",
"dev": true,
"license": "MIT"
},
"node_modules/util-deprecate": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz",
@@ -8447,6 +8961,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/whatwg-encoding": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/whatwg-encoding/-/whatwg-encoding-2.0.0.tgz",
"integrity": "sha512-p41ogyeMUrw3jWclHWTQg1k05DSVXPLcVxRTYsXUk+ZooOCZLcoYgPZ/HL/D/N+uQPOtcp1me1WhBEaX02mhWg==",
"dev": true,
"license": "MIT",
"dependencies": {
"iconv-lite": "0.6.3"
},
"engines": {
"node": ">=12"
}
},
"node_modules/which": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
+1
View File
@@ -52,6 +52,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
+4 -2
View File
@@ -2,8 +2,10 @@ import { defineConfig } from '@playwright/test';
export default defineConfig({
webServer: {
command: 'npm run build && npx http-server ../public -p 8181',
port: 8181
command: 'npm run build && http-server ../public -p 8181',
port: 8181,
timeout: 120000,
reuseExistingServer: false
},
testDir: 'e2e'
});
@@ -26,6 +26,7 @@
MimeTypeImage,
MimeTypeText
} from '$lib/enums/files';
import { isIMEComposing } from '$lib/utils/is-ime-composing';
interface Props {
class?: string;
@@ -97,7 +98,7 @@
}
async function handleKeydown(event: KeyboardEvent) {
if (event.key === 'Enter' && !event.shiftKey) {
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return;
@@ -1,6 +1,7 @@
<script lang="ts">
import { getDeletionInfo } from '$lib/stores/chat.svelte';
import { copyToClipboard } from '$lib/utils/copy';
import { isIMEComposing } from '$lib/utils/is-ime-composing';
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
import ChatMessageUser from './ChatMessageUser.svelte';
@@ -93,7 +94,9 @@
}
function handleEditKeydown(event: KeyboardEvent) {
if (event.key === 'Enter' && !event.shiftKey) {
// Check for IME composition using isComposing property and keyCode 229 (specifically for IME composition on Safari)
// This prevents saving edit when confirming IME word selection (e.g., Japanese/Chinese input)
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
handleSaveEdit();
} else if (event.key === 'Escape') {
@@ -7,18 +7,19 @@
const processingState = useProcessingState();
let isCurrentConversationLoading = $derived(isLoading());
let processingDetails = $derived(processingState.getProcessingDetails());
let showSlotsInfo = $derived(isCurrentConversationLoading || config().keepStatsVisible);
let showSlotsInfo = $derived(isLoading() || config().keepStatsVisible);
// Track loading state reactively by checking if conversation ID is in loading conversations array
$effect(() => {
const keepStatsVisible = config().keepStatsVisible;
if (keepStatsVisible || isLoading()) {
if (keepStatsVisible || isCurrentConversationLoading) {
processingState.startMonitoring();
}
if (!isLoading() && !keepStatsVisible) {
if (!isCurrentConversationLoading && !keepStatsVisible) {
setTimeout(() => {
if (!config().keepStatsVisible) {
processingState.stopMonitoring();
@@ -27,18 +28,20 @@
}
});
// Update processing state from stored timings
$effect(() => {
activeConversation();
const conversation = activeConversation();
const messages = activeMessages() as DatabaseMessage[];
const keepStatsVisible = config().keepStatsVisible;
if (keepStatsVisible) {
if (keepStatsVisible && conversation) {
if (messages.length === 0) {
slotsService.clearState();
slotsService.clearConversationState(conversation.id);
return;
}
// Search backwards through messages to find most recent assistant message with timing data
// Using reverse iteration for performance - avoids array copy and stops at first match
let foundTimingData = false;
for (let i = messages.length - 1; i >= 0; i--) {
@@ -47,15 +50,18 @@
foundTimingData = true;
slotsService
.updateFromTimingData({
prompt_n: message.timings.prompt_n || 0,
predicted_n: message.timings.predicted_n || 0,
predicted_per_second:
message.timings.predicted_n && message.timings.predicted_ms
? (message.timings.predicted_n / message.timings.predicted_ms) * 1000
: 0,
cache_n: message.timings.cache_n || 0
})
.updateFromTimingData(
{
prompt_n: message.timings.prompt_n || 0,
predicted_n: message.timings.predicted_n || 0,
predicted_per_second:
message.timings.predicted_n && message.timings.predicted_ms
? (message.timings.predicted_n / message.timings.predicted_ms) * 1000
: 0,
cache_n: message.timings.cache_n || 0
},
conversation.id
)
.catch((error) => {
console.warn('Failed to update processing state from stored timings:', error);
});
@@ -64,7 +70,7 @@
}
if (!foundTimingData) {
slotsService.clearState();
slotsService.clearConversationState(conversation.id);
}
}
});

Some files were not shown because too many files have changed in this diff Show More