Compare commits

...

49 Commits

Author SHA1 Message Date
lhez 9c96465f99 opencl: enable the general fp mm for non-cont input and as a fallback for specialized kqv kernel for adreno (#18970)
* opencl: add `copy_to_contiguous` and utilize mm kernels

* opencl: only copy to cont for f32 and f16 tensors

* opencl: use cont mm for fallback when dst is large

* opencl: use nb local to copy-to-cont

* opencl: use local offset as well
2026-01-22 10:29:25 -08:00
Xuan-Son Nguyen 4e595b250a server: do not log certain endpoints (avoid log spam) (#19028) 2026-01-22 19:24:37 +01:00
Georgi Gerganov 0e4ebeb057 quant : manual overrides of tensor types take precedence (#18952) 2026-01-22 16:17:06 +02:00
Aaron Teo 8b30840703 release: update github api (#19022) 2026-01-22 21:38:02 +08:00
Xuan-Son Nguyen 9eb5bfec1a mtmd : update docs to use llama_model_n_embd_inp (#18999) 2026-01-22 14:36:32 +01:00
손희준 c6926d1d95 server: Reorder methods in server-task.cpp (#19016)
* Move `task_result_state::update_chat_msg` to match with header

* Move `server_task_result_cmpl_partial::to_json_anthropic()` to match with header

---------

Co-authored-by: openingnow <>
2026-01-22 14:36:04 +01:00
Aman Gupta b70d251076 CUDA: add gqa_ratio 4 for GLM 4.7 flash (#18953) 2026-01-22 18:51:53 +08:00
shaofeiqi 5516b9c16a opencl: add TRI op support (#18979) 2026-01-21 22:05:54 -08:00
Aleksei Nikiforov 94242a62c0 ggml-zdnn : mark zDNN buffers as non-host (#18967)
While buffers reside in host memory,
additional transformation is needed to use buffers with zDNN.

Fixes #18848
2026-01-22 01:16:21 +01:00
Pádraic Slattery 6b99a223e3 ci : update GitHub Actions versions [no ci] (#18935) 2026-01-22 00:57:18 +01:00
Mariusz Woloszyn 77078e80e5 convert : add Devstral-2 (Ministral3ForCausalLM) arch (#18972)
* Add Ministral3ForCausalLM architecture

This adds support for newer architectures like Devstral-2

* removed blank line found after function decorator

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-22 00:55:55 +01:00
Piotr Wilkin (ilintar) c301172f66 jinja: support none|string (#18995)
* jinja: support none|string

* Update common/jinja/value.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update tests/test-jinja.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Add as_string()

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-21 19:24:37 +01:00
Hendrik Erz 3802d3c78f fix: Use tabular-nums for chat message statistics (#18915)
* fix: Use `tabular-nums` for chat message statistics

* fix: Rebuild WebUI
2026-01-21 18:46:01 +01:00
Daniel Bevenius 9da3dcd753 llama : clarify nemotron-h.cpp comment about RoPE [no ci] (#18997)
This commit removes the mention of RoPE in the comment for the Q and K
computation as RoPE is not applied.
2026-01-21 18:31:34 +01:00
Jeff Bolz bd544c94a3 vulkan: Remove transfer_ctx, do everything in compute_ctx. (#18945)
* vulkan: Remove transfer_ctx, do everything in compute_ctx.

We had a bug where a set_tensor_async (using transfer_ctx) didn't get
submitted before the graph_compute (using compute_ctx) that came after
it. To avoid this sort of issue, just do everything in compute_ctx.

Remove transfer_cmd_pool, which was already unused.

* fix crash with perf logger
2026-01-21 18:01:40 +01:00
Adrien Gallouët 14be5a39b1 common : improve error message when HTTPS is missing but required (#18987)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-01-21 17:58:38 +01:00
손희준 fbbf3ad190 server: /v1/responses (partial) (#18486)
* from previous PR

* Make instruction(system) as first message

* Convert [input_message] (text/image/file)

* Rename convert_responses_to_chatcmpl(body) -> response_body

* Initial tool call support

* Erase instructions field from chatcmpl body

* Feed reasoning texts to chat template

* Use std::vector instead of opaque json array

* Make output_item.added events consistent

* Move `server_task_result_cmpl_partial::update` from header to source

* Match ID of output_item.added and .done events

* Add function_call only if there is no "fc_" prefix

* Add function call output at non-streaming API

* Test if ID is persistent

* Add doc

* Fix style - use trailing comma

* Rewrite state management

* catch up with upstream/master

* Fix style - "type" is the first item of SSE data

* Explicitly check "instructions" from response_body

* Make lambdas static

* Check if reasoning content exists

* Add `oai_resp_id` to task_result_state (also initialized at ctor), server_task_result_cmpl_partial, and server_task_result_cmpl_final

* Reject `input_file` since it is not supported by chatcmpl

* Add "fc_" prefix to non-streaming function call id as coderabbit pointed out

---------

Co-authored-by: openingnow <>
2026-01-21 17:47:23 +01:00
Jeff Bolz 33f890e579 vulkan: support flash attention GQA/split_k with small batches (#18938) 2026-01-21 17:43:43 +01:00
Masato Nakasaka 067b8d7af3 Revert "vulkan: force full subgroups for flash attention to fix intel subgroup crash (#17356)" (#18831)
This reverts commit 980b7cd17e.
2026-01-21 17:13:43 +01:00
Jeff Bolz 50b7f076a5 vulkan: Use mul_mat_vec_id for small values of n (#18918)
Change ggml_vk_mul_mat_vec_id_q_f16 to loop over the batch dimension and
update the indexing calculations in get_offsets.

Mat-vec is faster than mat-mat for small values of n. We don't get the same
reuse of the weights as in the non-ID path, but with this the cost is linear
in n rather than n>1 being far slower than n==1.
2026-01-21 16:22:02 +01:00
Tarek Dakhran ad8d85bd94 memory : add llama_memory_hybrid_iswa (#18601)
* memory : add llama_memory_hybrid_iswa

* Update src/llama-memory-hybrid-iswa.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-01-21 14:30:23 +02:00
Piotr Wilkin (ilintar) 12a4a47e6a Fix GLM 4.7 Lite MoE gating func (#18980)
* Fix GLM 4.7 MoE gating func

* Update src/models/deepseek2.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-01-21 12:35:20 +01:00
Matthieu Coudron 37c35f0e1c gguf: display strerrno when can't load a model (#18884)
I've had issues loading models with llama-server:
[44039] E gguf_init_from_file: failed to open GGUF file 'mistral-7b-v0.1.Q8_0.gguf'

and I was sure it could access the file. Seems like --models-dir and
--models-presets don't interact like I thought they would, but I salvaged
this snippet that helps with troubleshooting:
[44039] E gguf_init_from_file: failed to open GGUF file 'mistral-7b-v0.1.Q8_0.gguf' (errno No such file or directory)
2026-01-21 08:52:46 +02:00
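
The commit above appends the errno text to the "failed to open GGUF file" log line. A minimal Python sketch of the same idea follows; it is not the C code from the PR, and the helper name `describe_open_failure` is hypothetical. It only shows how an OS error code can be turned into a readable reason with `os.strerror`.

```python
import os


def describe_open_failure(path):
    """Try to open `path`; return None on success, else a message with the errno text.

    Hypothetical helper for illustration only; the message layout imitates
    the log line quoted in the commit message.
    """
    try:
        with open(path, "rb"):
            return None
    except OSError as exc:
        # os.strerror maps the numeric errno to the platform's description.
        reason = os.strerror(exc.errno)
        return f"failed to open GGUF file '{path}' (errno {reason})"


if __name__ == "__main__":
    print(describe_open_failure("/nonexistent-dir-xyz/model.gguf"))
```

With the reason in the message, a wrong path and a permission problem are no longer reported identically.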
Oliver Simons 5bd341c9a1 CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator (#18964)
* CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator

Strided iterator was added in [CCCL
3.1](https://github.com/NVIDIA/cccl/releases/tag/v3.1.0), which is packaged into
[CTK
13.1](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id5)

* Unindent as per code review request
2026-01-21 02:34:29 +01:00
Adrien Gallouët 1c7cf94b22 common, server : use the same User-Agent by default (#18957)
This commit also ensures that if a custom User-Agent is used, it will be
the only one sent.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-01-20 18:28:43 +01:00
Xuan-Son Nguyen 2c1f199653 cli : fix reasoning responses in CLI (#18961)
* cli : fix reasoning responses in CLI

* fix build

* fix build (2)
2026-01-20 18:23:25 +01:00
Oliver Simons d1e3556481 CUDA: Replace init_offsets kernel with iterators in cub-based argsort (#18930)
* CUDA: Replace `init_offsets` with iterators in argsort

This is a QOL improvement, saving us the cost of materializing the
iterator

* Remove unnecessary include from top-k.cu
2026-01-20 20:11:01 +08:00
Adrien Gallouët 08f3f4a8a3 ggml : cleanup path_str() (#18928)
- Remove pragmas as `std::codecvt_utf8` is not used.
- Avoid implicit `strlen()`.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-01-20 11:42:49 +01:00
Georgi Gerganov 271191906c metal : enable FA for MLA heads (#18950) 2026-01-20 12:21:28 +02:00
Daniel Bevenius 7dee9ff59a convert : use n_groups instead of hardcoded values in reshape (#18929)
* convert : use n_groups instead of hardcoded values in reshape

This commit modifies the conversion script for NemotronHModel to use
the 'n_groups' hyperparameter, and allow Python to calculate the
last dimension, using -1, when reshaping the 'mixer.norm.weight' tensor.

* use self.n_group instead of self.hparams["n_groups"]
2026-01-20 06:55:24 +01:00
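
The commit above replaces hardcoded reshape dimensions with `n_groups` and a `-1` placeholder for the last dimension. The sketch below illustrates what that placeholder means, using plain Python lists rather than the conversion script's tensors; `reshape_by_groups` is a hypothetical name, and the sizes in the example are made up rather than taken from NemotronH.

```python
def reshape_by_groups(flat, n_groups):
    """Split a flat sequence into `n_groups` rows, inferring the row length.

    Mirrors the meaning of reshape(n_groups, -1): the last dimension is
    computed from the total element count instead of being hardcoded.
    """
    total = len(flat)
    if n_groups <= 0 or total % n_groups != 0:
        raise ValueError(f"cannot split {total} elements into {n_groups} groups")
    last_dim = total // n_groups  # the value that -1 stands in for
    return [list(flat[i * last_dim:(i + 1) * last_dim]) for i in range(n_groups)]


if __name__ == "__main__":
    # 12 elements in 4 groups: the inferred last dimension is 3.
    print(reshape_by_groups(range(12), 4))
```

Inferring the last dimension keeps the reshape valid when a model variant changes the hidden size, as long as the element count is still divisible by the number of groups.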
Xuan-Son Nguyen 6df686bee6 server : refactor oai_parser_opt, move it to server_chat_params (#18937)
* server_chat_params

* move chat format into CLI

* use meta whenever possible

* clean up, no more chatml fallback
2026-01-19 23:28:01 +01:00
ddh0 1706a6d7c6 convert : support Glm4MoeLite (#18936)
* initial commit for branch

* add glm-4.7-flash, move tokenizer hash

* use `glm4` pretok

* silence flake8 E302 (CI)

* apply review feedback

* add <|user|> as eog

* also add EOG `<|observation|>`

* revert llama-vocab

* inherit vocab from glm4

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-01-19 23:09:20 +01:00
Sigbjørn Skjæret 959ecf7f23 jinja : fix undefined keys and attributes and int/float as bool (#18924)
* fix undefined keys and attributes

* add falsy tests

* as_bool for integers and floats

* more falsy/truthy tests

* --typo
2026-01-19 20:29:43 +01:00
Sigbjørn Skjæret 4037093c66 ci : run test-jinja -py on high perf [no ci] (#18916) 2026-01-19 20:29:15 +01:00
Lennart Austenfeld 18361c579c server: fix memory reservations in populate_token_probs (#18787) 2026-01-19 19:13:31 +01:00
Georgi Gerganov 365a3e8c31 ggml : add ggml_build_forward_select (#18550)
* ggml : add ggml_build_forward_select

* cuda : adapt CUDA graph compat to new feature

* vulkan : update logic to handle command buffer closing

* ggml : check compute for fusion

* ggml : add comment
2026-01-19 20:03:19 +02:00
Daniel Bevenius 3d55846a5c model-conversion : add BUILD_DIR variable to run-converted-model scripts (#18927)
This commit adds a BUILD_DIR variable to the scripts used for running
converted models.

The motivation for this is that currently the `build` directory is
hardcoded and it can be useful to specify a different build directory,
with builds for different configurations.
2026-01-19 13:12:38 +01:00
Julius Tischbein 287a33017b llama : Extend fallback, fix fileno for dio file, exclude case that mmap uses dio file (#18887) 2026-01-18 18:35:57 +02:00
Francisco Herrera 293a1565dc docs: add linux to index (#18907) 2026-01-18 18:03:35 +08:00
Xuan-Son Nguyen fe44d35574 tests : add test-jinja -py option for cross-checking (#18906)
* tests : add test-jinja -py option for cross-checking

* Update tests/test-jinja.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix + add source

* SandboxedEnvironment

* fix array.map case

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-18 08:14:27 +01:00
Sigbjørn Skjæret bbcdac0189 jinja : fix object item order (and properly implement dictsort) (#18904)
* fix object item order

* as_ordered_object

* copy whole object
2026-01-18 03:40:06 +01:00
Sigbjørn Skjæret d03c45c9c5 jinja : attribute support for join, map and sort (#18883)
* support negative array index and default value

* attribute support (int and str) for join, map and sort

* add tests

* update CODEOWNERS

* improve fixme sorting comment
2026-01-18 02:53:01 +01:00
Sigbjørn Skjæret 10c98cbdf6 jinja : add missing tojson filter for bool (#18900)
* add missing tojson for bool

* add more literal tests
2026-01-18 01:05:09 +01:00
Sigbjørn Skjæret 420960ab92 jinja : fix lexing of float literals with sign (#18901)
* fix lexing of float literals with sign

* add test

* consume_numeric
2026-01-18 00:57:51 +01:00
Xuan-Son Nguyen f55b033ae6 jinja: correct member access rule (#18905) 2026-01-18 00:48:55 +01:00
lhez d1b4757ded opencl: fix q6_K mv for m=1 (#18893) 2026-01-17 13:50:32 -08:00
Sigbjørn Skjæret 57c0beaed0 ci : add label for jinja changes (#18903) 2026-01-17 21:52:02 +01:00
Georgi Gerganov 2fbde785bc kv-cache : optimize KQ mask construction (#18842)
* kv-cache : optimize KQ mask construction

* cont : add explanation + improve

* cont : fix
2026-01-17 15:42:42 +02:00
Reese Levine a89002f07b ggml webgpu: support for backend sampling (#18880)
* ggml webgpu: add SOFTPLUS unary operator

Implements SOFTPLUS (log(1 + exp(x))) with f16/f32 support. Uses f32
precision for intermediate calculations to prevent f16 overflow.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support
* Follow Vulkan backend numerical stability pattern

* ggml webgpu: add EXPM1 unary operator

Implements EXPM1 (exp(x) - 1) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add FLOOR unary operator

Implements FLOOR (rounds down to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add CEIL unary operator

Implements CEIL (rounds up to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add ROUND unary operator

Implements ROUND (rounds to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add TRUNC unary operator

Implements TRUNC (truncates towards zero) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* docs : update WebGPU support for unary operators (FLOOR, CEIL, ROUND, TRUNC, EXPM1, SOFTPLUS)

* Updates to webgpu get_memory

* Add argmax

* Add argmax,cumsum,sum,sum_rows

* Add necessary CPY/GET_ROWS operators

* Support for argsort using multi-pass strategy

* Update set_rows for i32 indices, move to pre-wgsl

* Port unary operators to pre-wgsl and support FILL

* Implement PAD

* Add support for top-k

* clean up, scope pipeline init mutex

* fix newline

* Add support for log

* Update LOG for better precision, and ops doc

---------

Co-authored-by: Abhijit Ramesh <abhijitramesh2k@gmail.com>
2026-01-16 16:12:43 -08:00
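
The SOFTPLUS entry in the commit above notes that intermediates are computed in f32 to prevent f16 overflow. The Python sketch below illustrates why, under stated assumptions: it is not the WGSL shader, Python floats (doubles) stand in for the wider intermediate type, and the stable rewrite `max(x, 0) + log1p(exp(-|x|))` is a standard identity rather than something confirmed to be the form used in the PR. All function names are hypothetical.

```python
import math
import struct


def to_f16(x):
    """Round a float to the nearest half-precision value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses finite values beyond the f16 range (max 65504).
        return math.copysign(math.inf, x)


def softplus_all_f16(x):
    """log(1 + exp(x)) with every intermediate rounded to f16."""
    e = to_f16(math.exp(x))            # exp(x) exceeds 65504 for x above about 11.09
    return to_f16(math.log(to_f16(1.0 + e)))


def softplus_wide(x):
    """Same function with wide intermediates, rounded to f16 only at the end."""
    y = max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return to_f16(y)


if __name__ == "__main__":
    print(softplus_all_f16(12.0))  # the f16 intermediate overflows
    print(softplus_wide(12.0))     # the result itself fits comfortably in f16
```

The result of softplus for moderate inputs is small enough for f16; only the intermediate `exp(x)` is not, which is why widening the intermediate type is sufficient.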
132 changed files with 13208 additions and 9779 deletions
+4 -1
@@ -89,7 +89,10 @@ nix:
embedding:
- changed-files:
- any-glob-to-any-file: examples/embedding/
jinja parser:
- changed-files:
- any-glob-to-any-file:
- common/jinja/**
Ascend NPU:
- changed-files:
- any-glob-to-any-file:
+6 -6
@@ -16,7 +16,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Get latest Vulkan SDK version
id: vulkan_sdk_version
@@ -24,7 +24,7 @@ jobs:
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
- name: Setup Cache
uses: actions/cache@v4
uses: actions/cache@v5
id: cache-sdk
with:
path: ./vulkan_sdk
@@ -47,10 +47,10 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup Cache
uses: actions/cache@v4
uses: actions/cache@v5
id: cache-toolchain
with:
path: ./spacemit_toolchain
@@ -73,10 +73,10 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup Cache
uses: actions/cache@v4
uses: actions/cache@v5
id: cache-rocm
with:
path: C:\Program Files\AMD\ROCm
+1 -1
@@ -7,7 +7,7 @@ jobs:
linux:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
fetch-depth: 0
+7 -7
@@ -8,7 +8,7 @@ jobs:
# runs-on: ubuntu-24.04
# steps:
# - uses: actions/checkout@v4
# - uses: actions/checkout@v6
# - name: Setup Riscv
# run: |
# sudo dpkg --add-architecture riscv64
@@ -52,7 +52,7 @@ jobs:
# runs-on: ubuntu-24.04
# steps:
# - uses: actions/checkout@v4
# - uses: actions/checkout@v6
# - name: Setup Riscv
# run: |
# sudo dpkg --add-architecture riscv64
@@ -99,7 +99,7 @@ jobs:
# runs-on: ubuntu-24.04
# steps:
# - uses: actions/checkout@v4
# - uses: actions/checkout@v6
# - name: Setup Arm64
# run: |
# sudo dpkg --add-architecture arm64
@@ -146,7 +146,7 @@ jobs:
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Setup LoongArch
run: |
rm -f /etc/apt/sources.list.d/*
@@ -201,7 +201,7 @@ jobs:
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Setup LoongArch
run: |
rm -f /etc/apt/sources.list.d/*
@@ -262,10 +262,10 @@ jobs:
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Use SpacemiT Toolchain Cache
uses: actions/cache@v4
uses: actions/cache@v5
id: cache-toolchain
with:
path: ./spacemit_toolchain
+57 -57
@@ -63,7 +63,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -99,7 +99,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -135,7 +135,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -189,7 +189,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -269,7 +269,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -317,7 +317,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Dependencies
id: depends
@@ -347,7 +347,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
# - name: ccache
# uses: ggml-org/ccache-action@v1.2.16
@@ -380,7 +380,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -414,7 +414,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -436,7 +436,7 @@ jobs:
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
- name: Use Vulkan SDK Cache
uses: actions/cache@v4
uses: actions/cache@v5
id: cache-sdk
with:
path: ./vulkan_sdk
@@ -472,7 +472,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -494,7 +494,7 @@ jobs:
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
- name: Use Vulkan SDK Cache
uses: actions/cache@v4
uses: actions/cache@v5
id: cache-sdk
with:
path: ./vulkan_sdk
@@ -543,7 +543,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -585,7 +585,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Dependencies
id: depends
@@ -616,7 +616,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Dependencies
id: depends
@@ -644,7 +644,7 @@ jobs:
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: add oneAPI to apt
shell: bash
@@ -668,7 +668,7 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -693,7 +693,7 @@ jobs:
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: add oneAPI to apt
shell: bash
@@ -717,7 +717,7 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -749,7 +749,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -781,7 +781,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -813,7 +813,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Build
id: cmake_build
@@ -843,7 +843,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -853,7 +853,7 @@ jobs:
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download xcframework artifact
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
name: llama-xcframework
path: build-apple/llama.xcframework/
@@ -885,7 +885,7 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -954,7 +954,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1053,7 +1053,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install dependencies
env:
@@ -1092,7 +1092,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1145,7 +1145,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1177,7 +1177,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Grab rocWMMA package
id: grab_rocwmma
@@ -1187,7 +1187,7 @@ jobs:
7z x data.tar
- name: Use ROCm Installation Cache
uses: actions/cache@v4
uses: actions/cache@v5
id: cache-rocm
with:
path: C:\Program Files\AMD\ROCm
@@ -1239,7 +1239,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup Xcode
uses: maxim-lobanov/setup-xcode@v1
@@ -1269,7 +1269,7 @@ jobs:
./build-xcframework.sh
- name: Upload xcframework artifact
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: llama-xcframework
path: build-apple/llama.xcframework/
@@ -1285,7 +1285,7 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v4
uses: actions/checkout@v6
# Disabled due to size (400MB) and always 0 cache hits
# - name: ccache
@@ -1295,7 +1295,7 @@ jobs:
# evict-old-files: 1d
- name: Set up JDK
uses: actions/setup-java@v3
uses: actions/setup-java@v5
with:
java-version: 17
distribution: zulu
@@ -1327,7 +1327,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install OpenCL Headers and Libs
id: install_opencl
@@ -1402,7 +1402,7 @@ jobs:
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -1460,7 +1460,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1486,7 +1486,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1512,7 +1512,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1538,7 +1538,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1564,7 +1564,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1590,7 +1590,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Test
id: ggml-ci
@@ -1604,7 +1604,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Test
id: ggml-ci
@@ -1618,7 +1618,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Test
id: ggml-ci
@@ -1632,7 +1632,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Test
id: ggml-ci
@@ -1645,7 +1645,7 @@ jobs:
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v4
# uses: actions/checkout@v6
# - name: Test
# id: ggml-ci
@@ -1659,7 +1659,7 @@ jobs:
# steps:
# - name: Clone
# id: checkout
# uses: actions/checkout@v4
# uses: actions/checkout@v6
# - name: Test
# id: ggml-ci
@@ -1673,7 +1673,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Test
id: ggml-ci
@@ -1686,7 +1686,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Dawn Dependency
id: dawn-depends
@@ -1714,7 +1714,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Test
id: ggml-ci
@@ -1728,7 +1728,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1773,7 +1773,7 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Check environment
run: |
@@ -1875,7 +1875,7 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup ccache
run: |
@@ -1969,7 +1969,7 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup ccache
run: |
@@ -2043,7 +2043,7 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup ccache
run: |
@@ -2089,7 +2089,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Dependencies
id: depends
+2 -2
@@ -23,12 +23,12 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: '3.x'
+1 -1
@@ -15,7 +15,7 @@ jobs:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v5
- uses: actions/stale@v10
with:
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
days-before-issue-stale: 30
+2 -2
@@ -26,7 +26,7 @@ jobs:
# If you do not check out your code, Copilot will do this for you.
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -45,7 +45,7 @@ jobs:
sudo chmod +x /usr/local/bin/git-clang-format
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.11'
+3 -3
@@ -49,7 +49,7 @@ jobs:
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
steps:
- name: Check out the repo
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0 # preserve git history, so we can determine the build number
@@ -63,7 +63,7 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
@@ -208,7 +208,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
+1 -1
@@ -22,7 +22,7 @@ jobs:
editorconfig:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- uses: editorconfig-checker/action-editorconfig-checker@v2
with:
version: v3.0.3
+2 -2
@@ -24,9 +24,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.9.x'
- name: Install dependencies
+2 -2
@@ -9,9 +9,9 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
repository: "ggml-org/llama.cpp"
- uses: actions/labeler@v5
- uses: actions/labeler@v6
with:
configuration-path: '.github/labeler.yml'
+2 -2
@@ -16,10 +16,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.11'
@@ -24,9 +24,9 @@ jobs:
name: check-requirements
steps:
- name: Check out source repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python environment
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Run check-requirements.sh script
+2 -2
@@ -19,9 +19,9 @@ jobs:
name: Lint
steps:
- name: Check out source repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python environment
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: flake8 Lint
+2 -2
@@ -24,9 +24,9 @@ jobs:
name: pyright type-check
steps:
- name: Check out source repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python environment
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Install Python dependencies
+28 -28
@@ -27,7 +27,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -63,7 +63,7 @@ jobs:
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
name: llama-bin-macos-arm64.tar.gz
@@ -74,7 +74,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -111,7 +111,7 @@ jobs:
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
name: llama-bin-macos-x64.tar.gz
@@ -133,7 +133,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -173,7 +173,7 @@ jobs:
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz
@@ -184,7 +184,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -226,7 +226,7 @@ jobs:
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
name: llama-bin-ubuntu-vulkan-x64.tar.gz
@@ -242,7 +242,7 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -278,7 +278,7 @@ jobs:
7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-bin-win-cpu-${{ matrix.arch }}.zip
name: llama-bin-win-cpu-${{ matrix.arch }}.zip
@@ -305,7 +305,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -360,7 +360,7 @@ jobs:
7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
@@ -375,7 +375,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Install ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -416,7 +416,7 @@ jobs:
7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
@@ -431,7 +431,7 @@ jobs:
7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*
- name: Upload Cuda runtime
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
@@ -451,7 +451,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -511,7 +511,7 @@ jobs:
7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/*
- name: Upload the release package
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-bin-win-sycl-x64.zip
name: llama-bin-win-sycl-x64.zip
@@ -531,7 +531,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Grab rocWMMA package
id: grab_rocwmma
@@ -542,7 +542,7 @@ jobs:
- name: Cache ROCm Installation
id: cache-rocm
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: C:\Program Files\AMD\ROCm
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
@@ -617,7 +617,7 @@ jobs:
7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\*
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-bin-win-hip-${{ matrix.name }}-x64.zip
name: llama-bin-win-hip-${{ matrix.name }}-x64.zip
@@ -627,7 +627,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -672,7 +672,7 @@ jobs:
zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
name: llama-${{ steps.tag.outputs.name }}-xcframework.zip
@@ -703,7 +703,7 @@ jobs:
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -763,7 +763,7 @@ jobs:
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
@@ -794,7 +794,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
@@ -804,7 +804,7 @@ jobs:
- name: Download artifacts
id: download-artifact
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
path: ./artifact
merge-multiple: true
@@ -887,7 +887,7 @@ jobs:
- name: Upload release
id: upload_release
uses: actions/github-script@v3
uses: actions/github-script@v8
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
@@ -897,7 +897,7 @@ jobs:
for (let file of await fs.readdirSync('./release')) {
if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) {
console.log('uploadReleaseAsset', file);
await github.repos.uploadReleaseAsset({
await github.rest.repos.uploadReleaseAsset({
owner: context.repo.owner,
repo: context.repo.repo,
release_id: release_id,
+5 -5
@@ -37,14 +37,14 @@ jobs:
continue-on-error: true
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
id: node
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "npm"
@@ -131,14 +131,14 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Python setup
id: setup_python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.11'
@@ -148,7 +148,7 @@ jobs:
pip install -r tools/server/tests/requirements.txt
- name: Setup Node.js for WebUI
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "npm"
+4 -4
@@ -64,7 +64,7 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
@@ -77,7 +77,7 @@ jobs:
- name: Python setup
id: setup_python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.11'
@@ -100,7 +100,7 @@ jobs:
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
@@ -113,7 +113,7 @@ jobs:
- name: Python setup
id: setup_python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.11'
+2 -2
@@ -18,10 +18,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.x'
+1 -1
@@ -21,7 +21,7 @@ jobs:
- name: Find latest release
id: find_latest_release
uses: actions/github-script@v6
uses: actions/github-script@v8
with:
script: |
const { data: releases } = await github.rest.repos.listReleases({
+1
@@ -15,6 +15,7 @@
/common/common.* @ggerganov
/common/console.* @ggerganov
/common/http.* @angt
/common/jinja/ @ngxson @CISC @aldehir
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/peg-parser.* @aldehir
+1 -1
@@ -254,7 +254,7 @@ function gg_run_ctest_release {
(time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
if [ -z ${GG_BUILD_LOW_PERF} ]; then
(time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log
(time ctest --output-on-failure -L 'main|python' ) 2>&1 | tee -a $OUT/${ci}-ctest.log
else
(time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
fi
+3 -3
@@ -129,7 +129,7 @@ static void parse_json_tool_calls(
}
}
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
result_.role = "assistant";
@@ -1611,7 +1611,7 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
@@ -1635,7 +1635,7 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
return msg;
}
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
if (parser.empty()) {
throw std::runtime_error("Failed to parse due to missing parser definition.");
}
+4 -4
@@ -5,7 +5,7 @@
#include "json-partial.h"
#include "regex-partial.h"
#include <nlohmann/json.hpp>
#include <nlohmann/json_fwd.hpp>
#include <optional>
#include <string>
@@ -19,20 +19,20 @@ class common_chat_msg_partial_exception : public std::runtime_error {
class common_chat_msg_parser {
std::string input_;
bool is_partial_;
common_chat_syntax syntax_;
common_chat_parser_params syntax_; // TODO: rename to params
std::string healing_marker_;
size_t pos_ = 0;
common_chat_msg result_;
public:
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
const std::string & input() const { return input_; }
size_t pos() const { return pos_; }
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
const common_chat_syntax & syntax() const { return syntax_; }
const common_chat_parser_params & syntax() const { return syntax_; }
void move_to(size_t pos) {
if (pos > input_.size()) {
+7 -7
@@ -601,18 +601,18 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp
return tmpls->has_explicit_template;
}
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
if (variant != nullptr) {
if (strcmp(variant, "tool_use") == 0) {
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
if (!variant.empty()) {
if (variant == "tool_use") {
if (tmpls->template_tool_use) {
return tmpls->template_tool_use->source().c_str();
return tmpls->template_tool_use->source();
}
return nullptr;
return "";
} else {
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str());
}
}
return tmpls->template_default->source().c_str();
return tmpls->template_default->source();
}
common_chat_templates_ptr common_chat_templates_init(
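The hunk above switches `common_chat_templates_source` from returning `const char *` (with `nullptr` for a missing variant) to returning `std::string` (with `""` for a missing variant). A minimal sketch of that contract follows, assuming simplified stand-in types; `template_sketch`, `templates_sketch` and `templates_source_sketch` are hypothetical names, not part of the real code.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for the real template types.
struct template_sketch {
    std::string src;
    const std::string & source() const { return src; }
};

struct templates_sketch {
    std::unique_ptr<template_sketch> template_default;
    std::unique_ptr<template_sketch> template_tool_use; // may be null
};

// Returns a copy of the template source.
// An empty string stands for "no such template" (the old code used nullptr).
std::string templates_source_sketch(const templates_sketch & t, const std::string & variant = "") {
    if (variant == "tool_use") {
        return t.template_tool_use ? t.template_tool_use->source() : std::string();
    }
    // An empty or unknown variant falls back to the default template.
    return t.template_default->source();
}
```

Callers that previously checked the result against `nullptr` would check `.empty()` instead under this contract.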
+17 -8
@@ -145,7 +145,7 @@ struct common_chat_templates_inputs {
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::map<std::string, std::string> chat_template_kwargs;
@@ -165,14 +165,21 @@ struct common_chat_params {
std::string parser;
};
struct common_chat_syntax {
// per-message parsing syntax
// should be derived from common_chat_params
struct common_chat_parser_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
common_peg_arena parser = {};
common_chat_parser_params() = default;
common_chat_parser_params(const common_chat_params & chat_params) {
format = chat_params.format;
thinking_forced_open = chat_params.thinking_forced_open;
}
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
@@ -191,7 +198,7 @@ common_chat_templates_ptr common_chat_templates_init(
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
struct common_chat_params common_chat_templates_apply(
@@ -213,10 +220,12 @@ std::string common_chat_format_example(
const std::map<std::string, std::string> & chat_template_kwargs);
const char* common_chat_format_name(common_chat_format format);
const char* common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
// used by arg and server
const char * common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
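The renamed `common_chat_parser_params` gains a non-explicit constructor taking `common_chat_params`, which copies `format` and `thinking_forced_open`. A minimal sketch of how such a converting constructor behaves, using trimmed-down hypothetical structs (`chat_params_sketch`, `parser_params_sketch`) rather than the real ones:

```cpp
#include <cassert>
#include <string>

// Hypothetical, trimmed-down versions of the two structs in the diff.
enum chat_format_sketch { FORMAT_CONTENT_ONLY, FORMAT_PEG_NATIVE };

struct chat_params_sketch {
    chat_format_sketch format = FORMAT_CONTENT_ONLY;
    bool thinking_forced_open = false;
    std::string prompt;
};

struct parser_params_sketch {
    chat_format_sketch format = FORMAT_CONTENT_ONLY;
    bool thinking_forced_open = false;
    bool parse_tool_calls = true;

    parser_params_sketch() = default;
    // Non-explicit on purpose, mirroring the diff: copies only the fields
    // the parser needs and leaves the rest at their defaults.
    parser_params_sketch(const chat_params_sketch & p)
        : format(p.format), thinking_forced_open(p.thinking_forced_open) {}
};

// A function that takes parser params can be called with chat params directly.
bool wants_forced_thinking(const parser_params_sketch & p) {
    return p.thinking_forced_open;
}
```

Because the constructor is not marked `explicit`, passing chat params where parser params are expected compiles without a cast; whether that implicit conversion is desirable is a design choice the diff itself does not discuss.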
+3
@@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT;
extern const char * LLAMA_COMPILER;
extern const char * LLAMA_BUILD_TARGET;
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
struct common_control_vector_load_info;
//
@@ -284,6 +286,7 @@ struct common_params_diffusion {
};
// reasoning API response format (not to be confused as chat template's reasoning format)
// only used by server
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
+19 -14
@@ -314,23 +314,26 @@ static bool common_pull_file(httplib::Client & cli,
// download one single file from remote URL to local path
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers) {
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
auto [cli, parts] = common_http_client(url);
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
if (!bearer_token.empty()) {
default_headers.insert({"Authorization", "Bearer " + bearer_token});
}
httplib::Headers headers;
for (const auto & h : custom_headers) {
default_headers.emplace(h.first, h.second);
headers.emplace(h.first, h.second);
}
cli.set_default_headers(default_headers);
if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info);
}
if (!bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + bearer_token);
}
cli.set_default_headers(headers);
const bool file_exists = std::filesystem::exists(path);
@@ -437,10 +440,12 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
const common_remote_params & params) {
auto [cli, parts] = common_http_client(url);
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
for (const auto & header : params.headers) {
headers.emplace(header.first, header.second);
httplib::Headers headers;
for (const auto & h : params.headers) {
headers.emplace(h.first, h.second);
}
if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info);
}
if (params.timeout > 0) {
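Both hunks above change the header-building order: caller-supplied headers are inserted first, a default `User-Agent` of `"llama-cpp/" + build_info` is added only when none is present, and `Authorization` is appended last. A minimal sketch of that order, assuming a plain `std::multimap` as a stand-in for `httplib::Headers` (so key lookup here is case-sensitive, which may differ from the real type); `build_headers_sketch` is a hypothetical helper.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Stand-in for httplib::Headers. Assumption: a plain std::multimap, which
// compares keys case-sensitively, so "user-agent" would not match "User-Agent".
using headers_sketch = std::multimap<std::string, std::string>;

headers_sketch build_headers_sketch(const std::vector<std::pair<std::string, std::string>> & custom,
                                    const std::string & bearer_token,
                                    const std::string & build_info) {
    headers_sketch headers;
    for (const auto & h : custom) {
        headers.emplace(h.first, h.second); // caller-supplied headers go in first
    }
    if (headers.find("User-Agent") == headers.end()) {
        headers.emplace("User-Agent", "llama-cpp/" + build_info); // default only when absent
    }
    if (!bearer_token.empty()) {
        headers.emplace("Authorization", "Bearer " + bearer_token);
    }
    return headers;
}
```

With the previous order the default `User-Agent` was inserted before the custom headers, so a custom value ended up as a second entry next to the default rather than replacing it.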
+11
@@ -57,6 +57,17 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
throw std::runtime_error("error: invalid URL format");
}
#ifndef CPPHTTPLIB_OPENSSL_SUPPORT
if (parts.scheme == "https") {
throw std::runtime_error(
"HTTPS is not supported. Please rebuild with:\n"
" -DLLAMA_BUILD_BORINGSSL=ON\n"
" -DLLAMA_BUILD_LIBRESSL=ON\n"
"or ensure dev files of an OpenSSL-compatible library are available when building."
);
}
#endif
httplib::Client cli(parts.scheme + "://" + parts.host);
if (!parts.user.empty()) {
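The added block rejects `https` URLs up front when the build has no TLS support, instead of failing later inside the client. A minimal sketch of the same compile-time guard, where `SKETCH_TLS_SUPPORT` is a hypothetical macro standing in for `CPPHTTPLIB_OPENSSL_SUPPORT` and `require_scheme_support_sketch` is an illustrative helper:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// SKETCH_TLS_SUPPORT is a hypothetical macro standing in for
// CPPHTTPLIB_OPENSSL_SUPPORT; it is left undefined here on purpose.
void require_scheme_support_sketch(const std::string & scheme) {
#ifndef SKETCH_TLS_SUPPORT
    if (scheme == "https") {
        // Fail early with an actionable message rather than at connect time.
        throw std::runtime_error("HTTPS is not supported in this build");
    }
#else
    (void) scheme; // TLS compiled in: nothing to reject
#endif
}
```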
+12 -7
@@ -91,6 +91,16 @@ lexer_result lexer::tokenize(const std::string & source) {
return str;
};
auto consume_numeric = [&]() -> std::string {
std::string num = consume_while(is_integer);
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
++pos; // Consume '.'
std::string frac = consume_while(is_integer);
num += "." + frac;
}
return num;
};
auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
if (pos + n >= src.size()) return false;
for (char c : chars) {
@@ -258,7 +268,7 @@ lexer_result lexer::tokenize(const std::string & source) {
++pos; // Consume the operator
// Check for numbers following the unary operator
std::string num = consume_while(is_integer);
std::string num = consume_numeric();
std::string value = std::string(1, ch) + num;
token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
// JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
@@ -307,12 +317,7 @@ lexer_result lexer::tokenize(const std::string & source) {
// Numbers
if (is_integer(ch)) {
start_pos = pos;
std::string num = consume_while(is_integer);
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
++pos; // Consume '.'
std::string frac = consume_while(is_integer);
num += "." + frac;
}
std::string num = consume_numeric();
// JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
tokens.push_back({token::numeric_literal, num, start_pos});
continue;
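The lexer hunks above move the integer-plus-fraction logic into one `consume_numeric` lambda, so that a number following a unary operator (for example `-0.5`) is lexed the same way as an unprefixed one. A standalone sketch of that rule, assuming a free function with an explicit position argument instead of the lexer's captured state; `consume_numeric_sketch` is a hypothetical name.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>

// Hypothetical standalone helper (not the lexer's real interface): reads an
// integer part, then a fractional part only when '.' is followed by a digit.
// `pos` is advanced past whatever was consumed.
std::string consume_numeric_sketch(const std::string & src, size_t & pos) {
    auto is_digit = [](char c) { return std::isdigit(static_cast<unsigned char>(c)) != 0; };
    std::string num;
    while (pos < src.size() && is_digit(src[pos])) {
        num += src[pos++];
    }
    if (pos + 1 < src.size() && src[pos] == '.' && is_digit(src[pos + 1])) {
        num += src[pos++]; // consume '.'
        while (pos < src.size() && is_digit(src[pos])) {
            num += src[pos++];
        }
    }
    return num;
}
```

The look-ahead on the character after `.` is what keeps an expression such as `7.upper` from being read as a malformed float: only `7` is consumed and the dot is left for the member-access path.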
+23 -11
@@ -268,8 +268,7 @@ value binary_expression::execute_impl(context & ctx) {
// String in object
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
auto key = left_val->as_string().str();
auto & obj = right_val->as_object();
bool has_key = obj.find(key) != obj.end();
bool has_key = right_val->has_key(key);
if (op.value == "in") {
return mk_val<value_bool>(has_key);
} else if (op.value == "not in") {
@@ -464,7 +463,7 @@ value for_statement::execute_impl(context & ctx) {
std::vector<value> items;
if (is_val<value_object>(iterable_val)) {
JJ_DEBUG("%s", "For loop over object keys");
auto & obj = iterable_val->as_object();
auto & obj = iterable_val->as_ordered_object();
for (auto & p : obj) {
auto tuple = mk_val<value_array>();
if (iterable_val->val_obj.is_key_numeric) {
@@ -560,6 +559,7 @@ value for_statement::execute_impl(context & ctx) {
for (size_t i = 0; i < filtered_items.size(); i++) {
JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size());
value_object loop_obj = mk_val<value_object>();
loop_obj->has_builtins = false; // loop object has no builtins
loop_obj->insert("index", mk_val<value_int>(i + 1));
loop_obj->insert("index0", mk_val<value_int>(i));
loop_obj->insert("revindex", mk_val<value_int>(filtered_items.size() - i));
@@ -717,6 +717,7 @@ value member_expression::execute_impl(context & ctx) {
value property;
if (this->computed) {
// syntax: obj[expr]
JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
int64_t arr_size = 0;
@@ -745,10 +746,24 @@ value member_expression::execute_impl(context & ctx) {
property = this->property->execute(ctx);
}
} else {
// syntax: obj.prop
if (!is_stmt<identifier>(this->property)) {
throw std::runtime_error("Non-computed member property must be an identifier");
throw std::runtime_error("Static member property must be an identifier");
}
property = mk_val<value_string>(cast_stmt<identifier>(this->property)->val);
std::string prop = property->as_string().str();
JJ_DEBUG("Member expression, object type %s, static property '%s'", object->type().c_str(), prop.c_str());
// behavior of jinja2: obj having prop as a built-in function AND 'prop', as an object key,
// then obj.prop returns the built-in function, not the property value.
// while obj['prop'] returns the property value.
// example: {"obj": {"items": 123}} -> obj.items is the built-in function, obj['items'] is 123
value val = try_builtin_func(ctx, prop, object, true);
if (!is_val<value_undefined>(val)) {
return val;
}
// else, fallthrough to normal property access below
}
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
@@ -763,11 +778,8 @@ value member_expression::execute_impl(context & ctx) {
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
}
auto key = property->as_string().str();
auto & obj = object->as_object();
auto it = obj.find(key);
if (it != obj.end()) {
val = it->second;
} else {
val = object->at(key, val);
if (is_val<value_undefined>(val)) {
val = try_builtin_func(ctx, key, object, true);
}
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
@@ -793,7 +805,7 @@ value member_expression::execute_impl(context & ctx) {
} else if (is_val<value_string>(property)) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
val = try_builtin_func(ctx, key, object);
val = try_builtin_func(ctx, key, object, true);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}
@@ -802,7 +814,7 @@ value member_expression::execute_impl(context & ctx) {
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
}
auto key = property->as_string().str();
val = try_builtin_func(ctx, key, object);
val = try_builtin_func(ctx, key, object, true);
}
if (ctx.is_get_stats && val && object && property) {
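The comment added in this file describes a lookup-order difference: `obj.prop` resolves a built-in function before an object key of the same name, while `obj['prop']` resolves the key first and only then falls back to built-ins. A toy model of those two orders follows; `object_sketch`, `builtin_names_sketch`, `resolve_static` and `resolve_computed` are hypothetical, and the returned strings only describe which lookup won.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical model: an object is a string->int map, and "builtins" is a
// fixed set of method names.
struct object_sketch {
    std::map<std::string, int> keys;
};

static const std::set<std::string> builtin_names_sketch = {"items", "keys", "values", "get"};

// obj.prop : built-in first, then key
std::string resolve_static(const object_sketch & obj, const std::string & prop) {
    if (builtin_names_sketch.count(prop)) {
        return "builtin:" + prop;
    }
    auto it = obj.keys.find(prop);
    return it != obj.keys.end() ? "value:" + std::to_string(it->second) : "undefined";
}

// obj['prop'] : key first, then built-in
std::string resolve_computed(const object_sketch & obj, const std::string & prop) {
    auto it = obj.keys.find(prop);
    if (it != obj.keys.end()) {
        return "value:" + std::to_string(it->second);
    }
    return builtin_names_sketch.count(prop) ? "builtin:" + prop : "undefined";
}
```

For the example in the diff comment, `{"obj": {"items": 123}}`, this model gives the built-in for `obj.items` and `123` for `obj['items']`.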
+3 -2
@@ -56,6 +56,7 @@ struct context {
// src is optional, used for error reporting
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
env = mk_val<value_object>();
env->has_builtins = false; // context object has no builtins
env->insert("true", mk_val<value_bool>(true));
env->insert("True", mk_val<value_bool>(true));
env->insert("false", mk_val<value_bool>(false));
@@ -68,7 +69,7 @@ struct context {
context(const context & parent) : context() {
// inherit variables (for example, when entering a new scope)
auto & pvar = parent.env->as_object();
auto & pvar = parent.env->as_ordered_object();
for (const auto & pair : pvar) {
set_val(pair.first, pair.second);
}
@@ -265,7 +266,7 @@ struct comment_statement : public statement {
struct member_expression : public expression {
statement_ptr object;
statement_ptr property;
bool computed;
bool computed; // true if obj[expr] and false if obj.prop
member_expression(statement_ptr && object, statement_ptr && property, bool computed)
: object(std::move(object)), property(std::move(property)), computed(computed) {
+74 -54
@@ -698,6 +698,7 @@ const func_builtins & value_bool_t::get_builtins() const {
bool val = args.get_pos(0)->as_bool();
return mk_val<value_string>(val ? "True" : "False");
}},
{"tojson", tojson},
};
return builtins;
}
@@ -775,19 +776,30 @@ const func_builtins & value_array_t::get_builtins() const {
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("join() first argument must be an array");
}
value val_delim = args.get_kwarg_or_pos("d", 1);
value val_attribute = args.get_kwarg_or_pos("attribute", 2);
if (!val_attribute->is_undefined()) {
throw not_implemented_exception("array attribute join not implemented");
}
value val_delim = args.get_kwarg_or_pos("d", 1);
value attribute = args.get_kwarg_or_pos("attribute", 2);
const auto & arr = args.get_pos(0)->as_array();
std::string delim = is_val<value_string>(val_delim) ? val_delim->as_string().str() : "";
const bool attr_is_int = is_val<value_int>(attribute);
if (!attribute->is_undefined() && !is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("join() attribute must be string or integer");
}
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::string result;
for (size_t i = 0; i < arr.size(); ++i) {
if (!is_val<value_string>(arr[i]) && !is_val<value_int>(arr[i]) && !is_val<value_float>(arr[i])) {
value val_arr = arr[i];
if (!attribute->is_undefined()) {
if (attr_is_int && is_val<value_array>(val_arr)) {
val_arr = val_arr->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
val_arr = val_arr->at(attr_name);
}
}
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
throw raised_exception("join() can only join arrays of strings or numerics");
}
result += arr[i]->as_string().str();
result += val_arr->as_string().str();
if (i < arr.size() - 1) {
result += delim;
}
@@ -802,26 +814,30 @@ const func_builtins & value_array_t::get_builtins() const {
}},
{"tojson", tojson},
{"map", [](const func_args & args) -> value {
args.ensure_count(2, 3);
args.ensure_count(2);
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("map: first argument must be an array");
}
value attribute = args.get_kwarg_or_pos("attribute", 1);
if (is_val<value_int>(attribute)) {
throw not_implemented_exception("map: integer attribute not implemented");
if (!is_val<value_kwarg>(args.get_args().at(1))) {
throw not_implemented_exception("map: filter-mapping not implemented");
}
if (!is_val<value_string>(attribute)) {
value attribute = args.get_kwarg_or_pos("attribute", 1);
const bool attr_is_int = is_val<value_int>(attribute);
if (!is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("map: attribute must be string or integer");
}
std::string attr_name = attribute->as_string().str();
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->as_string().str();
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
auto out = mk_val<value_array>();
auto arr = args.get_pos(0)->as_array();
for (const auto & item : arr) {
if (!is_val<value_object>(item)) {
throw raised_exception("map: item is not an object");
value attr_val;
if (attr_is_int) {
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
} else {
attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
}
value attr_val = item->at(attr_name, default_val);
out->push_back(attr_val);
}
return out;
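After this hunk, `map` with an `attribute` no longer raises when an item is not an object; it yields the `default` value for that item instead. A reduced sketch of the string-attribute case, assuming items are plain string maps and the default is a string; `item_sketch` and `map_attribute_sketch` are hypothetical names, and the integer-index case is left out.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

using item_sketch = std::map<std::string, std::string>;

// Collect `attr` from every item, substituting `fallback` when an item lacks it.
std::vector<std::string> map_attribute_sketch(const std::vector<item_sketch> & items,
                                              const std::string & attr,
                                              const std::string & fallback) {
    std::vector<std::string> out;
    out.reserve(items.size());
    for (const auto & item : items) {
        auto it = item.find(attr);
        out.push_back(it != item.end() ? it->second : fallback);
    }
    return out;
}
```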
@@ -847,29 +863,35 @@ const func_builtins & value_array_t::get_builtins() const {
return arr_editable->pop_at(index);
}},
{"sort", [](const func_args & args) -> value {
args.ensure_count(1, 3);
args.ensure_count(1, 4);
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("sort: first argument must be an array");
}
bool reverse = args.get_kwarg("reverse", mk_val<value_undefined>())->as_bool();
value attribute = args.get_kwarg("attribute", mk_val<value_undefined>());
std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
value attribute = args.get_kwarg_or_pos("attribute", 3);
// FIXME: sorting is currently always case sensitive
//const bool case_sensitive = val_case->as_bool(); // undefined == false
const bool reverse = val_reverse->as_bool(); // undefined == false
const bool attr_is_int = is_val<value_int>(attribute);
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
value val_a = a;
value val_b = b;
if (!attribute->is_undefined()) {
if (!is_val<value_object>(a) || !is_val<value_object>(b)) {
throw raised_exception("sort: items are not objects");
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
val_a = a->at(attr_int);
val_b = b->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
val_a = a->at(attr_name);
val_b = b->at(attr_name);
} else {
throw raised_exception("sort: unsupported object attribute comparison");
}
val_a = attr.empty() ? a : a->at(attr);
val_b = attr.empty() ? b : b->at(attr);
}
if (reverse) {
return value_compare(val_a, val_b, value_compare_op::gt);
} else {
return !value_compare(val_a, val_b, value_compare_op::gt);
}
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
});
return mk_val<value_array>(arr);
}},
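The `sort` comparator above changes from `!value_compare(a, b, gt)` to `value_compare(a, b, lt)` in the non-reversed case. Assuming `gt` is a plain greater-than, the old form is "less than or equal", which returns true for equal elements and so is not a strict weak ordering as `std::sort` requires. A sketch of the strict form with plain integers (`sorted_sketch` is a hypothetical helper):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Sketch with plain ints instead of jinja values. std::sort requires a strict
// weak ordering, so comp(x, x) must be false: use '<' or '>', never '<=' / '>='.
std::vector<int> sorted_sketch(std::vector<int> v, bool reverse) {
    std::sort(v.begin(), v.end(), [reverse](int a, int b) {
        return reverse ? a > b : a < b;
    });
    return v;
}
```

Passing a non-strict comparator to `std::sort` is undefined behavior, so inputs with duplicate values are the case this change matters for.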
@@ -888,6 +910,11 @@ const func_builtins & value_array_t::get_builtins() const {
const func_builtins & value_object_t::get_builtins() const {
if (!has_builtins) {
static const func_builtins no_builtins = {};
return no_builtins;
}
static const func_builtins builtins = {
// {"default", default_value}, // cause issue with gpt-oss
{"get", [](const func_args & args) -> value {
@@ -902,18 +929,13 @@ const func_builtins & value_object_t::get_builtins() const {
if (args.count() == 3) {
default_val = args.get_pos(2);
}
const auto & obj = args.get_pos(0)->as_object();
const value obj = args.get_pos(0);
std::string key = args.get_pos(1)->as_string().str();
auto it = obj.find(key);
if (it != obj.end()) {
return it->second;
} else {
return default_val;
}
return obj->at(key, default_val);
}},
{"keys", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(mk_val<value_string>(pair.first));
@@ -922,7 +944,7 @@ const func_builtins & value_object_t::get_builtins() const {
}},
{"values", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(pair.second);
@@ -931,7 +953,7 @@ const func_builtins & value_object_t::get_builtins() const {
}},
{"items", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
auto item = mk_val<value_array>();
@@ -945,7 +967,7 @@ const func_builtins & value_object_t::get_builtins() const {
{"string", tojson},
{"length", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
return mk_val<value_int>(static_cast<int64_t>(obj.size()));
}},
{"tojson", [](const func_args & args) -> value {
@@ -958,21 +980,18 @@ const func_builtins & value_object_t::get_builtins() const {
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
value val_by = args.get_kwarg_or_pos("by", 2);
value val_reverse = args.get_kwarg_or_pos("reverse", 3);
// FIXME: sorting is case sensitive
// FIXME: sorting is currently always case sensitive
//const bool case_sensitive = val_case->as_bool(); // undefined == false
const bool reverse = val_reverse->as_bool(); // undefined == false
if (!val_by->is_undefined()) {
throw not_implemented_exception("dictsort by key not implemented");
}
if (reverse) {
throw not_implemented_exception("dictsort reverse not implemented");
}
value_t::map obj = val_input->val_obj; // copy
std::sort(obj.ordered.begin(), obj.ordered.end(), [&](const auto & a, const auto & b) {
return a.first < b.first;
const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
auto result = mk_val<value_object>(val_input); // copy
std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
if (by_value) {
return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
} else {
return reverse ? a.first > b.first : a.first < b.first;
}
});
auto result = mk_val<value_object>();
result->val_obj = std::move(obj);
return result;
}},
{"join", [](const func_args &) -> value {
@@ -986,6 +1005,7 @@ const func_builtins & value_none_t::get_builtins() const {
static const func_builtins builtins = {
{"default", default_value},
{"tojson", tojson},
{"string", [](const func_args &) -> value { return mk_val<value_string>("None"); }}
};
return builtins;
}
@@ -1169,7 +1189,7 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
}
oss << "]";
} else if (is_val<value_object>(val)) {
const auto & obj = val->val_obj.ordered; // IMPORTANT: need to keep exact order
const auto & obj = val->as_ordered_object(); // IMPORTANT: need to keep exact order
oss << "{";
if (!obj.empty()) {
oss << newline();
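The `dictsort` hunk above replaces the two "not implemented" branches with a single comparator that handles both the `by` and `reverse` arguments. A minimal standalone sketch of that comparator logic follows. It assumes plain `int` values in place of the template engine's `value` type, and `entry` and `dictsort_sketch` are hypothetical names that do not appear in the patch:

```cpp
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

// Stand-in for one (key, value) pair of an ordered object.
using entry = std::pair<std::string, int>;

// Sort a copy of the items by key or by value, ascending or descending.
// The input is taken by value so the caller's container is left untouched.
std::vector<entry> dictsort_sketch(std::vector<entry> items, bool by_value, bool reverse) {
    std::sort(items.begin(), items.end(), [&](const entry & a, const entry & b) {
        if (by_value) {
            return reverse ? a.second > b.second : a.second < b.second;
        }
        return reverse ? a.first > b.first : a.first < b.first;
    });
    return items;
}
```

Using a strict `>` or `<` in every branch keeps the comparator a strict weak ordering, which `std::sort` requires.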
+32 -5
@@ -146,7 +146,7 @@ struct value_t {
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
virtual bool is_none() const { return false; }
virtual bool is_undefined() const { return false; }
@@ -154,6 +154,9 @@ struct value_t {
throw std::runtime_error("No builtins available for type " + type());
}
virtual bool has_key(const std::string & key) {
return val_obj.unordered.find(key) != val_obj.unordered.end();
}
virtual value & at(const std::string & key, value & default_val) {
auto it = val_obj.unordered.find(key);
if (it == val_obj.unordered.end()) {
@@ -168,8 +171,20 @@ struct value_t {
}
return val_obj.unordered.at(key);
}
virtual value & at(size_t index) {
if (index >= val_arr.size()) {
virtual value & at(int64_t index, value & default_val) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
return default_val;
}
return val_arr[index];
}
virtual value & at(int64_t index) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
}
return val_arr[index];
@@ -188,6 +203,9 @@ struct value_int_t : public value_t {
virtual int64_t as_int() const override { return val_int; }
virtual double as_float() const override { return static_cast<double>(val_int); }
virtual string as_string() const override { return std::to_string(val_int); }
virtual bool as_bool() const override {
return val_int != 0;
}
virtual const func_builtins & get_builtins() const override;
};
using value_int = std::shared_ptr<value_int_t>;
@@ -204,6 +222,9 @@ struct value_float_t : public value_t {
if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals
return out;
}
virtual bool as_bool() const override {
return val_flt != 0.0;
}
virtual const func_builtins & get_builtins() const override;
};
using value_float = std::shared_ptr<value_float_t>;
@@ -286,6 +307,7 @@ using value_array = std::shared_ptr<value_array_t>;
struct value_object_t : public value_t {
bool has_builtins = true; // context and loop objects do not have builtins
value_object_t() = default;
value_object_t(value & v) {
val_obj = v->val_obj;
@@ -295,11 +317,16 @@ struct value_object_t : public value_t {
val_obj.insert(pair.first, pair.second);
}
}
value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
for (const auto & pair : obj) {
val_obj.insert(pair.first, pair.second);
}
}
void insert(const std::string & key, const value & val) {
val_obj.insert(key, val);
}
virtual std::string type() const override { return "Object"; }
virtual const std::map<std::string, value> & as_object() const override { return val_obj.unordered; }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
virtual bool as_bool() const override {
return !val_obj.unordered.empty();
}
@@ -315,12 +342,12 @@ struct value_none_t : public value_t {
virtual std::string type() const override { return "None"; }
virtual bool is_none() const override { return true; }
virtual bool as_bool() const override { return false; }
virtual string as_string() const override { return string("None"); }
virtual std::string as_repr() const override { return type(); }
virtual const func_builtins & get_builtins() const override;
};
using value_none = std::shared_ptr<value_none_t>;
struct value_undefined_t : public value_t {
std::string hint; // for debugging, to indicate where undefined came from
value_undefined_t(const std::string & h = "") : hint(h) {}
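The reworked `at(int64_t index)` overloads in this header accept negative indices, which count back from the end of the array. The normalization can be sketched in isolation as below. `normalize_index` is a hypothetical helper, not part of the patch, and it returns an empty `std::optional` where the patch either throws or returns the default value:

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>

// Map a possibly negative index onto [0, size), or report "out of bounds".
// -1 refers to the last element, -size to the first.
inline std::optional<std::size_t> normalize_index(int64_t index, std::size_t size) {
    if (index < 0) {
        index += static_cast<int64_t>(size);
    }
    if (index < 0 || static_cast<std::size_t>(index) >= size) {
        return std::nullopt;
    }
    return static_cast<std::size_t>(index);
}
```

The second `index < 0` check matters: an index such as `-4` on a three-element array is still negative after the shift and must be rejected before it is cast to an unsigned type.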
+1
@@ -1,5 +1,6 @@
#pragma once
// TODO: use json_fwd.hpp when possible
#include <nlohmann/json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
+35 -3
@@ -1078,6 +1078,9 @@ class TextModel(ModelBase):
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -2976,7 +2979,10 @@ class Llama4VisionModel(MmprojModel):
return []
@ModelBase.register("Mistral3ForConditionalGeneration")
@ModelBase.register(
"Mistral3ForConditionalGeneration",
"Ministral3ForCausalLM",
)
class Mistral3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.MISTRAL3
@@ -7458,7 +7464,7 @@ class DeepseekModel(TextModel):
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration"
"YoutuVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -8446,6 +8452,32 @@ class Glm4MoeModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("Glm4MoeLiteForCausalLM")
class Glm4MoeLiteModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# copied from Glm4MoeModel
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(TextModel):
model_arch = gguf.MODEL_ARCH.CHATGLM
@@ -9183,7 +9215,7 @@ class NemotronHModel(GraniteHybridModel):
return [(mapped_name, reshaped_data)]
if name.endswith("mixer.norm.weight"):
reshaped_data = data_torch.reshape(8, 512)
reshaped_data = data_torch.reshape(self.n_group, -1)
mapped_name = self.map_tensor_name(name)
return [(mapped_name, reshaped_data)]
+1
@@ -170,6 +170,7 @@ pre_computed_hashes = [
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
]
+1
@@ -8,6 +8,7 @@
- [CMake Options](#cmake-options)
- [Android](#android)
- [Windows 11 Arm64](#windows-11-arm64)
- [Linux](#Linux)
- [Known Issue](#known-issues)
- [TODO](#todo)
+17 -16
@@ -20,10 +20,10 @@ Legend:
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@@ -36,17 +36,17 @@ Legend:
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -63,7 +63,7 @@ Legend:
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@@ -73,8 +73,9 @@ Legend:
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -85,7 +86,7 @@ Legend:
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
@@ -96,7 +97,7 @@ Legend:
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
@@ -106,14 +107,14 @@ Legend:
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
+7683 -7584
File diff suppressed because it is too large
@@ -4,6 +4,7 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
BUILD_DIR="${2:-"$BUILD_DIR"}"
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@@ -13,6 +14,10 @@ if [ -z "$CONVERTED_MODEL" ]; then
exit 1
fi
cmake --build ../../build --target llama-debug -j8
if [ -z "$BUILD_DIR" ]; then
BUILD_DIR="../../build"
fi
../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
cmake --build ${BUILD_DIR} --target llama-debug -j8
${BUILD_DIR}/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
@@ -5,11 +5,16 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
BUILD_DIR="${3:-"$BUILD_DIR"}"
if [ -z "$MODEL_TESTING_PROMPT"]; then
if [ -z "$MODEL_TESTING_PROMPT" ]; then
MODEL_TESTING_PROMPT="Hello, my name is"
fi
if [ -z "$BUILD_DIR" ]; then
BUILD_DIR="../../build"
fi
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
echo "Error: Model path must be provided either as:" >&2
@@ -21,6 +26,6 @@ fi
echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT
cmake --build ../../build --target llama-debug -j8
cmake --build ${BUILD_DIR} --target llama-debug -j8
../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
@@ -28,6 +28,7 @@ done
# First try command line argument, then environment variable
CONVERTED_MODEL="${CONVERTED_MODEL:-"$CONVERTED_EMBEDDING_MODEL"}"
BUILD_DIR="${BUILD_DIR:-"../../build"}"
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@@ -50,5 +51,5 @@ fi
echo $CONVERTED_MODEL
cmake --build ../../build --target llama-debug -j8
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
cmake --build ${BUILD_DIR} --target llama-debug -j8
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
+39 -7
@@ -630,10 +630,11 @@ extern "C" {
// this tensor...
enum ggml_tensor_flag {
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
};
enum ggml_tri_type {
@@ -2577,11 +2578,42 @@ extern "C" {
struct ggml_tensor * grad,
struct ggml_tensor * sgd_params); // alpha, weight decay
// build forward multiple tensors and select one of them for computing
// this is useful for creating graphs that have constant topology but compute different things based on the input
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
//
// automatic differentiation
// nodes:
// | - build forward into the graph but do not compute
// c - build forward into the graph and compute
//
// | | ... c ... |
// | | ... c ... |
// | | ... c ... |
// [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx)
// c
// c
//
// example:
// struct ggml_tensor * curs[3];
//
// curs[0] = compute0(...);
// curs[1] = compute1(...);
// curs[2] = compute2(...);
//
// int idx = select_branch(some_input);
//
// struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
//
GGML_API struct ggml_tensor * ggml_build_forward_select(
struct ggml_cgraph * cgraph,
struct ggml_tensor ** tensors,
int n_tensors,
int idx);
GGML_API void ggml_build_forward_expand(
struct ggml_cgraph * cgraph,
struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(
struct ggml_context * ctx, // context for gradient computation
struct ggml_cgraph * cgraph,
@@ -2613,7 +2645,7 @@ extern "C" {
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
// dump the graph into a file using the dot format
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
// TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
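The new `GGML_TENSOR_FLAG_COMPUTE` flag pairs with the backend changes later in this diff, where each compute loop skips nodes that do not carry it. The sketch below mimics that skip on a toy graph. It does not use the ggml API: `sketch_node` and `run_graph` are hypothetical stand-ins, and only the flag value `16` and the `(flags & FLAG) == 0` test are taken from the diff:

```cpp
#include <cstdint>
#include <vector>

// Same numeric value as GGML_TENSOR_FLAG_COMPUTE in the header above.
constexpr int32_t SKETCH_FLAG_COMPUTE = 16;

// Toy replacement for a graph node: a flag bitmask plus an identifier.
struct sketch_node {
    int32_t flags;
    int     id;
};

// Walk the graph in order and "compute" only the nodes marked for computation.
// Nodes without the flag stay in the graph, so its topology never changes.
std::vector<int> run_graph(const std::vector<sketch_node> & nodes) {
    std::vector<int> computed;
    for (const auto & node : nodes) {
        if ((node.flags & SKETCH_FLAG_COMPUTE) == 0) {
            continue;
        }
        computed.push_back(node.id);
    }
    return computed;
}
```

Testing the bit with a mask rather than comparing for equality lets a node carry other flags, such as the output flag, alongside the compute flag.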
+4 -20
@@ -77,39 +77,23 @@
#include "ggml-zendnn.h"
#endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
namespace fs = std::filesystem;
static std::string path_str(const fs::path & path) {
std::string u8path;
try {
#if defined(__cpp_lib_char8_t)
// C++20 and later: u8string() returns std::u8string
std::u8string u8str = path.u8string();
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
const std::u8string u8str = path.u8string();
return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
#else
// C++17: u8string() returns std::string
u8path = path.u8string();
return path.u8string();
#endif
} catch (...) {
return std::string();
}
return u8path;
}
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
+3 -2
@@ -874,9 +874,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
}
if (sched->debug > 1) {
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
@@ -1922,6 +1922,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
dst->view_offs = src->view_offs;
}
dst->op = src->op;
dst->flags = src->flags;
memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
ggml_set_name(dst, src->name);
+4
@@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_backend_blas_mul_mat(ctx, node);
+4
@@ -2146,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+4
@@ -2943,6 +2943,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
ggml_compute_forward(&params, node);
if (state->ith == 0 && cplan->abort_callback &&
+19 -10
@@ -2,6 +2,9 @@
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
# define STRIDED_ITERATOR_AVAILABLE
# endif
using namespace cub;
#endif // GGML_CUDA_USE_CUB
@@ -14,12 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
}
}
#ifndef STRIDED_ITERATOR_AVAILABLE
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx <= nrows) {
offsets[idx] = idx * ncols;
}
}
#endif // STRIDED_ITERATOR_AVAILABLE
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
@@ -31,19 +36,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();
int * d_offsets = offsets_alloc.get();
static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
#ifdef STRIDED_ITERATOR_AVAILABLE
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
#else
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * offset_iterator = offsets_alloc.get();
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
#endif
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
size_t temp_storage_bytes = 0;
@@ -57,7 +65,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, stream);
offset_iterator, offset_iterator + 1, stream);
}
} else {
if (nrows == 1) {
@@ -66,7 +74,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
stream);
}
}
@@ -80,7 +89,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
}
} else {
if (nrows == 1) {
@@ -89,8 +98,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
stream);
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, stream);
}
}
}
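Both offset paths in the argsort hunks above describe the same segment boundaries to the segmented sort: row `i` begins at `i * ncols` and ends where row `i + 1` begins, which is why `nrows + 1` offsets are needed. A host-side sketch of that layout follows. `make_segment_offsets` is a hypothetical CPU stand-in for the `init_offsets` kernel, not code from the patch:

```cpp
#include <vector>

// Build the nrows + 1 boundaries of equally sized segments.
// Segment i covers the half-open range [offsets[i], offsets[i + 1]).
std::vector<int> make_segment_offsets(int ncols, int nrows) {
    std::vector<int> offsets(static_cast<std::size_t>(nrows) + 1);
    for (int idx = 0; idx <= nrows; ++idx) {
        offsets[static_cast<std::size_t>(idx)] = idx * ncols;
    }
    return offsets;
}
```

Because every entry is a fixed multiple of its position, the values can be generated on the fly instead of stored, which is the idea behind the strided-iterator branch.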
+1
@@ -1123,6 +1123,7 @@ struct ggml_tensor_extra_gpu {
struct ggml_cuda_graph_node_properties {
void * node_address;
ggml_op node_op;
int32_t flags;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
+23 -8
@@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = get_cols_per_thread();
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
// Wide version of KQ_C is column-major
if constexpr (cols_per_warp == 8) {
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
} else {
// Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
}
@@ -953,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = get_cols_per_thread();
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
#ifdef VOLTA_MMA_AVAILABLE
if (ncols1*ncols2 < 32) {
NO_DEVICE_CODE;
return;
}
#endif // VOLTA_MMA_AVAILABLE
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1728,3 +1738,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
// For GLM 4.7 Flash
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+12
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
@@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
return;
}
}
if constexpr (DV <= 256) {
+7 -3
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
GGML_ASSERT(gqa_ratio % 4 == 0);
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
}
} break;
default:
GGML_ABORT("fatal error");
@@ -262,7 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
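The dispatch and the kernel-selection check above can be read together as a small decision rule for the 576/512 head size. The following is a minimal sketch of that rule only. The name `pick_ncols2_for_576` and the return value 0 for "no MMA kernel" are assumptions made for illustration, and the separate `gqa_opt_applies` condition is not modelled.

```cpp
// Hypothetical helper: which ncols2 template argument the dispatch above
// would select for DKQ=576, DV=512, or 0 when the ratio is rejected
// (the selection code returns BEST_FATTN_KERNEL_NONE in that case).
constexpr int pick_ncols2_for_576(int gqa_ratio) {
    return gqa_ratio % 16 == 0 ? 16
         : gqa_ratio %  4 == 0 ?  4
         :                        0;
}

static_assert(pick_ncols2_for_576(16) == 16, "multiples of 16 keep the ncols2=16 kernel");
static_assert(pick_ncols2_for_576(20) ==  4, "other multiples of 4 use the ncols2=4 kernel");
static_assert(pick_ncols2_for_576( 6) ==  0, "anything else is not handled by this path");
```

The ratios 16, 20 and 6 are arbitrary test inputs, not values taken from any particular model.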
+8
@@ -2918,6 +2918,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
props->node_address = node->data;
props->node_op = node->op;
props->flags = node->flags;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
props->ne[i] = node->ne[i];
props->nb[i] = node->nb[i];
@@ -2961,6 +2962,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
return false;
}
return true;
}
@@ -3378,6 +3383,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
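The same `GGML_TENSOR_FLAG_COMPUTE` test recurs in several backends in this diff. A minimal sketch of the pattern follows, using a stand-in `Node` struct and a placeholder bit value `kFlagCompute`. Both are assumptions for illustration; the real flag value is defined by ggml and may differ.

```cpp
#include <cstdint>
#include <vector>

// Placeholder bit position, NOT the real GGML_TENSOR_FLAG_COMPUTE value.
constexpr int32_t kFlagCompute = 1 << 4;

// Hypothetical stand-in for the part of ggml_tensor used here.
struct Node {
    int32_t flags;
};

// A node is executed only if its compute bit is set.
constexpr bool should_execute(int32_t flags) {
    return (flags & kFlagCompute) != 0;
}

// Returns the indices a backend graph loop would actually run,
// skipping nodes that are not marked for compute.
std::vector<int> nodes_to_execute(const std::vector<Node> & nodes) {
    std::vector<int> out;
    for (int i = 0; i < (int) nodes.size(); ++i) {
        if (!should_execute(nodes[i].flags)) {
            continue;
        }
        out.push_back(i);
    }
    return out;
}
```

For CUDA graphs the hunk above also compares only this bit between the stored properties and the current node, so a change in the compute bit is treated as a graph change.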
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
@@ -85,7 +85,7 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq != 576 and ncols2 == 16:
continue
if head_size_kq == 576 and ncols2 != 16:
if head_size_kq == 576 and ncols2 not in (4, 16):
continue
head_size_v = head_size_kq if head_size_kq != 576 else 512
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
-1
@@ -4,7 +4,6 @@
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
# include <cuda/iterator>
# define CUB_TOP_K_AVAILABLE
using namespace cub;
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
+4
@@ -2497,6 +2497,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
uint32_t flags = 0;
// skip quantizer if src1 is reused
+3
@@ -611,6 +611,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in
if (node->op != ops[i]) {
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return false;
}
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
return false;
}
+2 -6
@@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek sizes
// TODO: disabled for now, until optmized
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {
+5 -1
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported op");
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return 1;
}
int n_fuse = 1;
// check if the current node can run concurrently with other nodes before it
@@ -2516,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
int32_t nsg = 4;
int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);
+8 -5
@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
constexpr short NC = (C/8)/NSG;
// note: do not unroll for large heads
#pragma unroll (DK <= 64 ? NC : 1)
for (short cc = 0; cc < NC; ++cc) {
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
if (DK % 16 != 0) {
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
k8x8_t mk[2];
q8x8_t mq[2];
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
// note: too much unroll can tank the performance for large heads
#pragma unroll (MIN(DK8/2, 4*NSG))
for (short i = 0; i < DK8/2; ++i) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
pv += 8*NS20;
}
} else {
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
constexpr short NC = (C/8)/2;
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
s8x8_t vs[2];
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS
+1
@@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS
add
add_id
argsort
tri
fill
clamp
cpy
+235 -13
@@ -398,6 +398,7 @@ struct ggml_backend_opencl_context {
int adreno_wave_size;
cl_bool non_uniform_workgroups;
size_t image_max_buffer_size;
cl_context context;
cl_command_queue queue;
@@ -407,6 +408,10 @@ struct ggml_backend_opencl_context {
ggml_cl_buffer prealloc_scales_trans;
ggml_cl_buffer prealloc_act_trans;
// prealloc buffers for src0 and src1
ggml_cl_buffer prealloc_src0;
ggml_cl_buffer prealloc_src1;
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
@@ -489,6 +494,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
cl_kernel kernel_relu;
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
cl_kernel kernel_tri;
cl_kernel kernel_fill;
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
@@ -793,6 +799,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// tri
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "tri.cl.h"
};
#else
const std::string kernel_src = read_file("tri.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err));
GGML_LOG_CONT(".");
CL_CHECK(clReleaseProgram(prog));
}
// fill
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2639,6 +2663,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL);
GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size);
clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
@@ -3058,6 +3085,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
@@ -3201,6 +3232,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
default:
return false;
}
case GGML_OP_TRI:
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
case GGML_OP_FILL:
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
case GGML_OP_CLAMP:
@@ -4686,6 +4719,81 @@ static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct gg
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
}
// Copy a noncontiguous tensor to contiguous tensor. ne[] remains the same but
// nb[] is recalculated such that tensor is contiguous.
static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst,
cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
const int tensor_type_size = ggml_type_size(src->type);
const int ne00 = src->ne[0];
const int ne01 = src->ne[1];
const int ne02 = src->ne[2];
const int ne03 = src->ne[3];
const cl_ulong nb00 = src->nb[0];
const cl_ulong nb01 = src->nb[1];
const cl_ulong nb02 = src->nb[2];
const cl_ulong nb03 = src->nb[3];
const int ne0 = src->ne[0];
const int ne1 = src->ne[1];
const int ne2 = src->ne[2];
const int ne3 = src->ne[3];
nb0 = tensor_type_size;
nb1 = tensor_type_size*ne00;
nb2 = tensor_type_size*ne00*ne01;
nb3 = tensor_type_size*ne00*ne01*ne02;
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
cl_ulong offset0 = extra->offset + src->view_offs;
cl_ulong offsetd = 0;
cl_kernel kernel;
switch (src->type) {
case GGML_TYPE_F32:
kernel = backend_ctx->kernel_cpy_f32_f32;
break;
case GGML_TYPE_F16:
kernel = backend_ctx->kernel_cpy_f16_f16;
break;
default:
GGML_ASSERT(false && "not implemented");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &dst));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
const int nth = MIN(64, ne00);
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src);
}
static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
UNUSED(backend);
UNUSED(src0);
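The stride recalculation in `ggml_cl_copy_to_contiguous` above can be sketched on its own as follows. The `Strides` struct and the function name `contiguous_strides` are hypothetical; the arithmetic follows the `nb0`..`nb3` assignments in the hunk.

```cpp
#include <cstdint>

// Hypothetical container for the four byte strides of a 4-D tensor.
struct Strides {
    uint64_t nb0, nb1, nb2, nb3;
};

// Byte strides of a densely packed tensor: each stride is the previous
// stride multiplied by the extent of the previous dimension.
constexpr Strides contiguous_strides(uint64_t type_size, uint64_t ne0, uint64_t ne1, uint64_t ne2) {
    return Strides{
        type_size,
        type_size * ne0,
        type_size * ne0 * ne1,
        type_size * ne0 * ne1 * ne2,
    };
}

// Example: an f32 tensor (4 bytes per element) with ne = {8, 3, 2, *}.
static_assert(contiguous_strides(4, 8, 3, 2).nb1 ==  32, "row stride");
static_assert(contiguous_strides(4, 8, 3, 2).nb2 ==  96, "plane stride");
static_assert(contiguous_strides(4, 8, 3, 2).nb3 == 192, "batch stride");
```

The last extent `ne3` does not enter the strides; it only determines the total buffer size.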
@@ -5961,6 +6069,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int tri_type = ggml_get_op_params_i32(dst, 0);
const int64_t n = ggml_nelements(dst);
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
cl_kernel kernel = backend_ctx->kernel_tri;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &tri_type));
size_t local_work_size[1] = { 256 };
size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
}
static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
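`ggml_cl_tri` above pads the global work size up to a multiple of the work-group size, and the kernel's `idx >= n` check discards the padding. A minimal sketch of that rounding, with the hypothetical helper name `round_up_to_multiple`:

```cpp
#include <cstddef>

// Smallest multiple of `local` that is >= n (local must be non-zero).
// Mirrors (n + local - 1) / local * local from the launch code above.
constexpr std::size_t round_up_to_multiple(std::size_t n, std::size_t local) {
    return (n + local - 1) / local * local;
}

static_assert(round_up_to_multiple(1000, 256) == 1024, "padded to 4 work-groups");
static_assert(round_up_to_multiple( 256, 256) ==  256, "exact multiples are unchanged");
static_assert(round_up_to_multiple(   1, 256) ==  256, "at least one full work-group");
```

The padding is only safe because every work-item checks its index against the real element count before touching memory.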
@@ -7661,9 +7807,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
cl_context context = backend_ctx->context;
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 &&
// dst is wrapped with image1d_buffer, the size limit applies, also src0
(ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) {
// For KQ
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) &&
nb00 <= nb02 &&
nb02 <= nb01 &&
nb01 <= nb03 &&
@@ -7674,7 +7823,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
return;
}
// For KQV
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) {
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
return;
}
@@ -7980,9 +8130,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
// GEMM using local memory
// Current BK = 16, so ne00 % 16 == 0
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
src1t == GGML_TYPE_F32 &&
if (src1t == GGML_TYPE_F32 &&
ne00 % 16 == 0 &&
ne11 > 1) {
switch(src0t) {
@@ -7994,10 +8142,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
cl_mem mem_src0 = extra0->data_device;
cl_mem mem_src1 = extra1->data_device;
cl_ulong nb00_cont = nb00;
cl_ulong nb01_cont = nb01;
cl_ulong nb02_cont = nb02;
cl_ulong nb03_cont = nb03;
cl_ulong nb10_cont = nb10;
cl_ulong nb11_cont = nb11;
cl_ulong nb12_cont = nb12;
cl_ulong nb13_cont = nb13;
cl_ulong offset0_cont = offset0;
cl_ulong offset1_cont = offset1;
if (!ggml_is_contiguous(src0)) {
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
mem_src0 = backend_ctx->prealloc_src0.buffer;
offset0_cont = 0;
}
if (!ggml_is_contiguous(src1)) {
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
mem_src1 = backend_ctx->prealloc_src1.buffer;
offset1_cont = 0;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
@@ -8029,10 +8209,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
cl_mem mem_src0 = extra0->data_device;
cl_mem mem_src1 = extra1->data_device;
cl_ulong nb00_cont = nb00;
cl_ulong nb01_cont = nb01;
cl_ulong nb02_cont = nb02;
cl_ulong nb03_cont = nb03;
cl_ulong nb10_cont = nb10;
cl_ulong nb11_cont = nb11;
cl_ulong nb12_cont = nb12;
cl_ulong nb13_cont = nb13;
cl_ulong offset0_cont = offset0;
cl_ulong offset1_cont = offset1;
if (!ggml_is_contiguous(src0)) {
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
mem_src0 = backend_ctx->prealloc_src0.buffer;
offset0_cont = 0;
}
if (!ggml_is_contiguous(src1)) {
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
mem_src1 = backend_ctx->prealloc_src1.buffer;
offset1_cont = 0;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
@@ -8060,6 +8272,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
if (ne11 < 32) {
break;
}
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
@@ -10008,6 +10224,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_glu;
break;
case GGML_OP_TRI:
if (!any_on_device) {
return false;
}
func = ggml_cl_tri;
break;
case GGML_OP_FILL:
if (!any_on_device) {
return false;
@@ -111,6 +111,10 @@ kernel void kernel_mul_mv_q6_K_f32(
int row = N_SIMDGROUP * r0 + get_sub_group_id();
if (row >= ne01) {
return;
}
int i12 = im%ne12;
int i13 = im/ne12;
+32
@@ -0,0 +1,32 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// tri
//------------------------------------------------------------------------------
__kernel void kernel_tri_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int n,
int ne0,
int ne1,
int tri_type
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
int idx = get_global_id(0);
if (idx >= n) return;
int i0 = idx % ne0;
int i1 = (idx / ne0) % ne1;
int keep = 0;
if (tri_type == 0) keep = (i0 >= i1);
else if (tri_type == 1) keep = (i0 > i1);
else if (tri_type == 2) keep = (i0 <= i1);
else keep = (i0 < i1);
dst[idx] = keep ? src0[idx] : 0.0f;
}
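For checking the kernel above on the host, a CPU reference could look like the sketch below. The names `tri_keep` and `tri_reference` are hypothetical, and the integer `tri_type` codes are used exactly as they appear in the kernel, without assuming which ggml enum names they correspond to.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same predicate as the kernel: i0 is the column index, i1 the row index.
constexpr bool tri_keep(int i0, int i1, int tri_type) {
    return tri_type == 0 ? (i0 >= i1)
         : tri_type == 1 ? (i0 >  i1)
         : tri_type == 2 ? (i0 <= i1)
         :                 (i0 <  i1);
}

// CPU reference: copies src and zeroes every element rejected by tri_keep.
// Elements beyond the first ne0*ne1 are treated as further matrices,
// matching the (idx / ne0) % ne1 indexing of the kernel.
std::vector<float> tri_reference(const std::vector<float> & src, int ne0, int ne1, int tri_type) {
    assert(ne0 > 0 && ne1 > 0);
    std::vector<float> dst(src.size());
    for (std::size_t idx = 0; idx < src.size(); ++idx) {
        const int i0 = (int) (idx % (std::size_t) ne0);
        const int i1 = (int) ((idx / (std::size_t) ne0) % (std::size_t) ne1);
        dst[idx] = tri_keep(i0, i1, tri_type) ? src[idx] : 0.0f;
    }
    return dst;
}
```

Calling `tri_reference` on a 3x3 matrix of ones with `tri_type == 2` keeps the entries with `i0 <= i1`, that is, the diagonal and everything to its left in each row.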
+3
@@ -4109,6 +4109,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
+121 -101
@@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants {
uint32_t fusion_flags;
uint32_t nei0;
uint32_t ne11;
uint32_t expert_i1;
uint32_t nbi1;
};
struct vk_flash_attn_push_constants {
@@ -1516,6 +1518,15 @@ struct vk_quantize_q8_1_push_constants {
uint32_t num_blocks;
};
struct vk_op_flash_attn_split_k_reduce_push_constants {
uint32_t D;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t k_num;
uint32_t sinks;
};
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1802,7 +1813,6 @@ struct ggml_backend_vk_context {
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
vk_context_ref compute_ctx;
vk_context_ref transfer_ctx;
std::vector<vk_context_ref> tensor_ctxs;
@@ -1812,7 +1822,6 @@ struct ggml_backend_vk_context {
uint32_t pipeline_descriptor_set_requirements {};
vk_command_pool compute_cmd_pool;
vk_command_pool transfer_cmd_pool;
// number of additional consecutive nodes that are being fused with the
// node currently being processed
@@ -3178,15 +3187,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} \
} \
@@ -3980,7 +3989,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
if (device->subgroup_clustered && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
@@ -5647,7 +5656,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
ctx->almost_ready_fence = ctx->device->device.createFence({});
ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
if (vk_perf_logger_enabled) {
ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
@@ -8083,8 +8091,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
GGML_ASSERT(nei1 == 1);
const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
@@ -8168,7 +8175,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -8226,7 +8233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
uint32_t stride_batch_y = ne10*ne11;
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
}
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
@@ -8262,23 +8269,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
}
// compute
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
fusion_flags,
(uint32_t)nei0, (uint32_t)ne11,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
d_ids,
},
pc, { groups_x, (uint32_t)nei0, groups_z });
// Loop over the batch dimension
for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
fusion_flags,
(uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
d_ids,
},
pc, { groups_x, (uint32_t)nei0, groups_z });
}
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
@@ -8292,7 +8301,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
ggml_tensor * dst = cgraph->nodes[node_idx];
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src2 = dst->src[2];
return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
}
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -8454,14 +8463,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(0);
}
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
gqa_ratio = qk_ratio;
N = gqa_ratio;
workgroups_y /= N;
workgroups_y /= gqa_ratio;
}
bool small_rows = N <= get_fa_num_small_rows(path);
@@ -8523,6 +8532,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
assert(pipeline);
// Compile early to initialize wg_denoms.
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
uint32_t split_kv = KV;
uint32_t split_k = 1;
@@ -8530,22 +8541,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0) {
// Try to use split_k when KV is large enough to be worth the overhead.
// Must either be a single batch or be using gqa, we can't mix the two.
if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
// Try to run two workgroups per SM.
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
split_k = CEIL_DIV(KV, split_kv);
workgroups_x = split_k;
}
}
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
// For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
GGML_ABORT("Requested preallocation size is too large");
}
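The split_k arithmetic in the hunk above can be tried in isolation. The following is a minimal sketch, not the backend's code: `choose_split_k`, `roundup_pow2` and `ceil_div` are hypothetical stand-ins for the inline logic and the `ROUNDUP_POW2` / `CEIL_DIV` macros, and it assumes `alignment` is a power of two. It deliberately leaves out the eligibility check on `workgroups_x` and `gqa_ratio`, and it keeps `split_k` at 1 when splitting would not help.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helpers mirroring ROUNDUP_POW2 / CEIL_DIV (align must be a power of two).
static uint32_t roundup_pow2(uint32_t x, uint32_t align) { return (x + align - 1) & ~(align - 1); }
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

struct split_result {
    uint32_t split_k;   // number of KV chunks
    uint32_t split_kv;  // KV elements per chunk
};

// Aim for roughly two workgroups per shader core, then recompute split_k so that
// every chunk length is a multiple of the alignment.
static split_result choose_split_k(uint32_t KV, uint32_t core_count,
                                   uint32_t wg_x, uint32_t wg_y, uint32_t wg_z,
                                   uint32_t alignment) {
    split_result r = { 1, KV };
    const uint32_t target = core_count * 2 / (wg_x * wg_y * wg_z);
    if (target > 1) {
        r.split_kv = roundup_pow2(std::max(1u, KV / target), alignment);
        r.split_k  = ceil_div(KV, r.split_kv);
    }
    return r;
}
```

Note that the dispatched grid already counting `workgroups_x` is what motivates dividing by all three workgroup counts rather than only y and z.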
@@ -8556,7 +8569,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
{
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
}
@@ -8605,7 +8617,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
workgroups_x *= pipeline->wg_denoms[0];
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
@@ -8613,15 +8625,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// there's no more than one tile of rows (i.e. workgroups_x would have been
// one). We reuse workgroups_x to mean the number of splits, so we need to
// cancel out the divide by wg_denoms[0].
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{split_k_buf, sinks_buf, dst_buf},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
ctx->prealloc_split_k_need_sync = true;
} else {
if (gqa_ratio > 1) {
// When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
workgroups_x *= pipeline->wg_denoms[0];
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
pc, { workgroups_x, workgroups_y, workgroups_z });
@@ -11560,7 +11576,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
free(d_chk);
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
ggml_vk_destroy_buffer(d_X);
ggml_vk_destroy_buffer(d_Y);
@@ -12145,7 +12160,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
ggml_vk_submit(subctx, {});
ctx->submit_pending = true;
ggml_vk_synchronize(ctx);
GGML_ASSERT(ctx->compute_ctx.expired());
ggml_vk_ctx_begin(ctx->device, subctx);
ctx->compute_ctx = subctx;
}
if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
@@ -12163,6 +12180,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
ggml_vk_destroy_buffer(ctx->prealloc_y);
}
ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
ctx->prealloc_y_last_tensor_used = nullptr;
}
if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -12191,6 +12209,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return false;
}
VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
ctx->semaphore_idx = 0;
@@ -12740,7 +12761,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
@@ -12769,7 +12789,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
// discard any unsubmitted command buffers
ctx->transfer_ctx.reset();
ctx->compute_ctx.reset();
// wait for any pending command buffers to finish
ggml_vk_synchronize(ctx);
@@ -12802,7 +12822,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ctx->descriptor_sets.clear();
ctx->compute_cmd_pool.destroy(ctx->device->device);
ctx->transfer_cmd_pool.destroy(ctx->device->device);
if (vk_perf_logger_enabled) {
ctx->perf_logger->print_timings(true);
}
@@ -13074,34 +13093,34 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_context transfer_ctx;
vk_context compute_ctx;
if (ctx->transfer_ctx.expired()) {
if (ctx->compute_ctx.expired()) {
// Initialize new transfer context
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->transfer_ctx = transfer_ctx;
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
transfer_ctx = ctx->transfer_ctx.lock();
compute_ctx = ctx->compute_ctx.lock();
}
vk_buffer buf = buf_ctx->dev_buffer;
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size);
if (!ret) {
ggml_vk_ensure_sync_staging_buffer(ctx, size);
ggml_vk_sync_buffers(nullptr, transfer_ctx);
ggml_vk_sync_buffers(nullptr, compute_ctx);
vk::BufferCopy buffer_cpy;
buffer_cpy.srcOffset = 0;
buffer_cpy.dstOffset = dst_offset;
buffer_cpy.size = size;
transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys);
ggml_vk_synchronize(ctx);
}
}
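The async entry points in this part of the diff all follow the same "reuse the pending context or start a new one" pattern around a `weak_ptr`. A minimal sketch of that pattern, assuming a simplified owner: `cmd_context`, `backend_sketch` and `owned` are hypothetical names, and `owned` merely stands in for whatever keeps real contexts alive.

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-in for a recorded command buffer.
struct cmd_context {
    std::vector<int> recorded;
};

struct backend_sketch {
    std::weak_ptr<cmd_context> compute_ctx;             // non-owning handle to the pending context
    std::vector<std::shared_ptr<cmd_context>> owned;    // assumed owner of all contexts

    // Reuse the pending context if there is one, otherwise create and register a new one.
    std::shared_ptr<cmd_context> get_or_create() {
        if (compute_ctx.expired()) {
            auto c = std::make_shared<cmd_context>();
            owned.push_back(c);
            compute_ctx = c;
            return c;
        }
        return compute_ctx.lock();
    }

    // After submission the handle is cleared, so the next caller starts a fresh context.
    void mark_submitted() { compute_ctx.reset(); }
};
```

Calling `reset()` on the weak handle makes `expired()` return true even while the owner still holds the object, which is why each submit path in the diff ends with `ctx->compute_ctx.reset()`.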
@@ -13113,34 +13132,34 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_context transfer_ctx;
vk_context compute_ctx;
if (ctx->transfer_ctx.expired()) {
if (ctx->compute_ctx.expired()) {
// Initialize new transfer context
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->transfer_ctx = transfer_ctx;
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
transfer_ctx = ctx->transfer_ctx.lock();
compute_ctx = ctx->compute_ctx.lock();
}
vk_buffer buf = buf_ctx->dev_buffer;
auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size);
bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
// If that failed, copy synchronously through a staging buffer
if (!ret) {
ggml_vk_ensure_sync_staging_buffer(ctx, size);
ggml_vk_sync_buffers(nullptr, transfer_ctx);
ggml_vk_sync_buffers(nullptr, compute_ctx);
vk::BufferCopy buffer_cpy;
buffer_cpy.srcOffset = src_offset;
buffer_cpy.dstOffset = 0;
buffer_cpy.size = size;
transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys);
compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
ggml_vk_synchronize(ctx);
}
}
@@ -13152,21 +13171,21 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
vk_context transfer_ctx;
vk_context compute_ctx;
if (ctx->transfer_ctx.expired()) {
if (ctx->compute_ctx.expired()) {
// Initialize new transfer context
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->transfer_ctx = transfer_ctx;
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
transfer_ctx = ctx->transfer_ctx.lock();
compute_ctx = ctx->compute_ctx.lock();
}
vk_buffer src_buf = src_buf_ctx->dev_buffer;
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
return true;
}
@@ -13176,19 +13195,19 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_
static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_synchronize()");
bool do_transfer = !ctx->transfer_ctx.expired();
bool do_transfer = !ctx->compute_ctx.expired();
vk_context transfer_ctx;
vk_context compute_ctx;
if (do_transfer) {
transfer_ctx = ctx->transfer_ctx.lock();
compute_ctx = ctx->compute_ctx.lock();
ggml_vk_ctx_end(transfer_ctx);
ggml_vk_ctx_end(compute_ctx);
for (auto& cpy : transfer_ctx->in_memcpys) {
for (auto& cpy : compute_ctx->in_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
ggml_vk_submit(transfer_ctx, {});
ggml_vk_submit(compute_ctx, {});
ctx->submit_pending = true;
}
@@ -13202,10 +13221,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
}
if (do_transfer) {
for (auto& cpy : transfer_ctx->out_memcpys) {
for (auto& cpy : compute_ctx->out_memcpys) {
memcpy(cpy.dst, cpy.src, cpy.n);
}
ctx->transfer_ctx.reset();
ctx->compute_ctx.reset();
}
}
@@ -13645,7 +13664,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
int last_node = cgraph->n_nodes - 1;
// If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
last_node -= 1;
}
@@ -13874,6 +13893,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ggml_vk_submit(compute_ctx, ctx->device->fence);
VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
ctx->device->device.resetFences({ ctx->device->fence });
ctx->compute_ctx.reset();
// Get the results and pass them to the logger
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
@@ -14160,15 +14180,15 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
vk_context transfer_ctx;
vk_context compute_ctx;
if (ctx->transfer_ctx.expired()) {
if (ctx->compute_ctx.expired()) {
// Initialize new transfer context
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->transfer_ctx = transfer_ctx;
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
transfer_ctx = ctx->transfer_ctx.lock();
compute_ctx = ctx->compute_ctx.lock();
}
// the backend interface doesn't have an explicit reset, so reset it here
@@ -14176,13 +14196,13 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
ctx->device->device.resetEvent(vkev->event);
ctx->device->device.resetFences({ vkev->fence });
ggml_vk_set_event(transfer_ctx, vkev->event);
ggml_vk_set_event(compute_ctx, vkev->event);
ggml_vk_ctx_end(transfer_ctx);
ggml_vk_ctx_end(compute_ctx);
ggml_vk_submit(transfer_ctx, {vkev->fence});
ggml_vk_submit(compute_ctx, {vkev->fence});
ctx->submit_pending = true;
ctx->transfer_ctx.reset();
ctx->compute_ctx.reset();
}
static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
@@ -14190,20 +14210,20 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
vk_event *vkev = (vk_event *)event->context;
vk_context transfer_ctx;
vk_context compute_ctx;
if (ctx->transfer_ctx.expired()) {
if (ctx->compute_ctx.expired()) {
// Initialize new transfer context
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->transfer_ctx = transfer_ctx;
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
ctx->compute_ctx = compute_ctx;
ggml_vk_ctx_begin(ctx->device, compute_ctx);
} else {
transfer_ctx = ctx->transfer_ctx.lock();
compute_ctx = ctx->compute_ctx.lock();
}
ggml_vk_wait_events(transfer_ctx, {vkev->event});
ggml_vk_ctx_end(transfer_ctx);
ctx->transfer_ctx.reset();
ggml_vk_wait_events(compute_ctx, {vkev->event});
ggml_vk_ctx_end(compute_ctx);
ctx->compute_ctx.reset();
}
// TODO: enable async and synchronize
@@ -53,7 +53,7 @@ void main() {
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
@@ -101,9 +101,9 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
@@ -320,7 +320,8 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
@@ -332,7 +333,7 @@ void main() {
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -378,7 +379,7 @@ void main() {
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
@@ -165,7 +165,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
@@ -173,12 +173,19 @@ void init_indices()
N = p.N;
KV = p.KV;
i = gl_WorkGroupID.x;
split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
} else if (p.gqa_ratio > 1) {
i = 0;
gqa_iq1 = gl_WorkGroupID.x;
split_k_index = 0;
} else {
i = gl_WorkGroupID.x;
gqa_iq1 = 0;
split_k_index = 0;
}
Tr = CEIL_DIV(N, Br);
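The index decoding in `init_indices()` above can be mirrored on the CPU to check which workgroup handles which batch row and KV split. This is a sketch under the assumption that the x dimension of the dispatch is `split_k * workgroups_x`, as in the host-side hunk; `fa_wg_indices` and `decode_wg_x` are hypothetical names.

```cpp
#include <cstdint>

struct fa_wg_indices {
    uint32_t i;              // row-tile index (only used without split_k or gqa)
    uint32_t gqa_iq1;        // batch row handled by this workgroup
    uint32_t split_k_index;  // which KV chunk this workgroup processes
};

// Mirrors the three branches of init_indices(): with split_k, batch row and split
// index share the x workgroup id; with gqa only, x is the batch row; otherwise x
// is the row tile.
static fa_wg_indices decode_wg_x(uint32_t wg_x, uint32_t k_num, uint32_t gqa_ratio) {
    fa_wg_indices r = { 0, 0, 0 };
    if (k_num > 1) {
        r.gqa_iq1       = wg_x / k_num;
        r.split_k_index = wg_x % k_num;
    } else if (gqa_ratio > 1) {
        r.gqa_iq1 = wg_x;
    } else {
        r.i = wg_x;
    }
    return r;
}
```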
@@ -90,7 +90,7 @@ void main() {
barrier();
}
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
@@ -141,9 +141,9 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
@@ -370,7 +370,8 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
@@ -382,7 +383,7 @@ void main() {
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -428,7 +429,7 @@ void main() {
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -111,7 +111,7 @@ void main() {
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
@@ -138,9 +138,9 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
uint32_t m_offset = 0;
uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
@@ -272,10 +272,11 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
@@ -325,7 +326,7 @@ void main() {
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
@@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne1;
uint ne2;
uint ne3;
uint k_num;
uint sinks;
@@ -24,15 +25,15 @@ void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.z % p.ne2;
const uint i3 = gl_WorkGroupID.z / p.ne2;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;
uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;
uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;
uint lm_stride = p.ne1 * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
@@ -99,7 +100,7 @@ void main() {
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
@@ -115,6 +116,6 @@ void main() {
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);
data_d[iq3 * D * N + D * n + d] = O;
data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
}
}
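The split_k scratch buffer layout used by the flash attention shaders and the reduce shader can be checked with a few lines of host code. The sketch below restates the offsets from the diff in units of floats: O partials are ordered [HSV, ne1, k, ne2, ne3], followed by the L/M rows ordered [ne1, k, ne2, ne3] with L and M stored back to back. `fa_dims` and the function names are hypothetical.

```cpp
#include <cstdint>

struct fa_dims {
    uint32_t HSV, ne1, ne2, ne3, k_num;
};

// Start of the O partial for split k and batch coordinates (i2, i3).
static uint64_t o_base(const fa_dims & d, uint32_t k, uint32_t i2, uint32_t i3) {
    return (uint64_t) d.HSV * d.ne1 * (k + d.k_num * (i2 + d.ne2 * i3));
}

// The L/M region begins after all O partials.
static uint64_t lm_region_start(const fa_dims & d) {
    return (uint64_t) d.HSV * d.ne1 * d.k_num * d.ne2 * d.ne3;
}

// L rows for (k, i2, i3); the matching M rows follow ne1 floats later.
static uint64_t l_base(const fa_dims & d, uint32_t k, uint32_t i2, uint32_t i3) {
    return lm_region_start(d) + (uint64_t) d.ne1 * 2 * (k + d.k_num * (i2 + d.ne2 * i3));
}

static uint64_t m_base(const fa_dims & d, uint32_t k, uint32_t i2, uint32_t i3) {
    return l_base(d, k, i2, i3) + d.ne1;
}

// Total float count; matches split_k_size / sizeof(float) on the host side.
static uint64_t total_floats(const fa_dims & d) {
    return ((uint64_t) d.HSV * d.ne1 + (uint64_t) d.ne1 * 2) * d.k_num * d.ne2 * d.ne3;
}
```

With these definitions the last O partial ends exactly where the L/M region starts, and the last M row ends exactly at the total size, which is the property the enlarged `split_k_size` preserves.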
@@ -29,6 +29,8 @@ layout (push_constant) uniform parameter
#ifdef MUL_MAT_ID
uint nei0;
uint ne11;
uint expert_i1;
uint nbi1;
#else
uint ne02;
uint ne12;
@@ -43,7 +45,7 @@ uint expert_id;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.y;
const uint expert_i0 = gl_GlobalInvocationID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
#endif
@@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
batch_idx_a = i03 * p.ne02 + i02;
}
#else
expert_id = data_ids[expert_idx];
expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
#endif
a_offset =
@@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#endif
b_offset =
#ifdef MUL_MAT_ID
(expert_idx % p.ne11) * p.stride_b;
(expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
#else
batch_idx * p.batch_stride_b;
#endif
d_offset =
#ifdef MUL_MAT_ID
expert_idx * p.stride_d;
expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
#else
batch_idx * p.batch_stride_d;
#endif
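The `MUL_MAT_ID` branches of `get_offsets()` above can be restated as plain C++ to see how the new `expert_i1` push constant selects a batch row. This is a sketch only: `mmvid_params`, `mmvid_offsets` and `mmvid_get_offsets` are hypothetical names, and it assumes `nbi1` is the row stride of the ids tensor measured in elements, which the diff does not show.

```cpp
#include <cstdint>
#include <vector>

struct mmvid_params {
    uint32_t ne11;
    uint32_t stride_b, stride_d;
    uint32_t batch_stride_b, batch_stride_d;
    uint32_t nbi1;  // assumed: row stride of the ids tensor, in elements
};

struct mmvid_offsets {
    uint32_t expert_id;
    uint32_t b_offset;
    uint32_t d_offset;
};

// expert_i0 is the y invocation index within one dispatch; expert_i1 is the batch
// row, passed per dispatch by the host loop.
static mmvid_offsets mmvid_get_offsets(const std::vector<uint32_t> & ids, const mmvid_params & p,
                                       uint32_t expert_i0, uint32_t expert_i1) {
    mmvid_offsets o;
    o.expert_id = ids[expert_i0 + expert_i1 * p.nbi1];
    o.b_offset  = (expert_i0 % p.ne11) * p.stride_b + expert_i1 * p.batch_stride_b;
    o.d_offset  = expert_i0 * p.stride_d + expert_i1 * p.batch_stride_d;
    return o;
}
```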
@@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+335 -36
@@ -9,12 +9,28 @@
#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_I32_SIZE_BYTES 4
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u
struct ggml_webgpu_flash_attn_shader_lib_context {
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
void * decisions;
};
// Same hash combine function as in boost
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
/** FlashAttention */
struct ggml_webgpu_flash_attn_pipeline_key {
ggml_type kv_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
@@ -22,11 +38,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context {
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
size_t wg_mem_limit_bytes;
uint32_t max_subgroup_size;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap;
}
};
struct ggml_webgpu_flash_attn_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.kv_type);
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.kv_direct);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
return seed;
}
};
struct ggml_webgpu_flash_attn_shader_lib_context {
ggml_webgpu_flash_attn_pipeline_key key;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
size_t wg_mem_limit_bytes;
uint32_t max_subgroup_size;
};
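The key/hash pair introduced above is the usual shape needed to use a struct as an `std::unordered_map` key. A minimal, self-contained sketch of the same idea follows; `pipeline_key`, `pipeline_key_hash` and `pipeline_cache` are hypothetical, reduced to three fields, with a string standing in for the cached pipeline. The important invariant is that the hash combines exactly the fields compared by `operator==`.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Boost-style hash combine, same formula as in the diff.
template <typename T> inline void hash_combine(size_t & seed, const T & value) {
    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct pipeline_key {
    uint32_t head_dim_qk;
    uint32_t head_dim_v;
    bool     has_mask;

    bool operator==(const pipeline_key & other) const {
        return head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
               has_mask == other.has_mask;
    }
};

struct pipeline_key_hash {
    size_t operator()(const pipeline_key & key) const {
        size_t seed = 0;
        hash_combine(seed, key.head_dim_qk);
        hash_combine(seed, key.head_dim_v);
        hash_combine(seed, key.has_mask);
        return seed;
    }
};

// Cache keyed by shader configuration; the value type is a placeholder.
using pipeline_cache = std::unordered_map<pipeline_key, std::string, pipeline_key_hash>;
```

Keeping device-dependent values such as the subgroup matrix sizes out of the key, as the diff does, is reasonable as long as they are constant for the lifetime of the cache.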
struct ggml_webgpu_flash_attn_shader_decisions {
@@ -35,12 +75,6 @@ struct ggml_webgpu_flash_attn_shader_decisions {
uint32_t wg_size = 0;
};
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
ggml_webgpu_flash_attn_shader_decisions decisions;
};
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
@@ -66,15 +100,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
}
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
const size_t limit_bytes = context.wg_mem_limit_bytes;
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes =
(context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!context.kv_direct) {
bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
if (!context.key.kv_direct) {
bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
}
if (context.has_mask) {
if (context.key.has_mask) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
@@ -90,7 +125,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
std::vector<std::string> defines;
std::string variant = "flash_attn";
-switch (context.kv_type) {
+switch (context.key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
@@ -106,32 +141,31 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
-variant += std::string("_") + ggml_type_name(context.kv_type);
+variant += std::string("_") + ggml_type_name(context.key.kv_type);
-if (context.has_mask) {
+if (context.key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
-if (context.has_sinks) {
+if (context.key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
-if (context.uses_logit_softcap) {
+if (context.key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
-if (context.kv_direct) {
+if (context.key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
-defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
-variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
-defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
-variant += std::string("_hsv") + std::to_string(context.head_dim_v);
+defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
+variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
+defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
// For now these are not part of the variant name
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
@@ -141,7 +175,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
-if (context.kv_direct) {
+if (context.key.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
// Avoids having to use bounds-checks and decreasing performance for direct KV loads
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
@@ -158,11 +192,276 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
-result.wgsl = preprocessor.preprocess(shader_src, defines);
-result.variant = variant;
-result.decisions.q_tile = q_tile;
-result.decisions.kv_tile = kv_tile;
-result.decisions.wg_size = wg_size;
+result.wgsl = preprocessor.preprocess(shader_src, defines);
+result.variant = variant;
+ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions();
+decisions->q_tile = q_tile;
+decisions->kv_tile = kv_tile;
+decisions->wg_size = wg_size;
+result.decisions = decisions;
return result;
}
/** Generic **/
struct ggml_webgpu_generic_shader_lib_context {
int vec4;
uint32_t max_wg_size;
};
struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_generic_shader_lib_context & context,
const std::string & base_variant) {
std::vector<std::string> defines;
std::string variant = base_variant;
if (context.vec4) {
defines.push_back("VEC4");
variant += "_vec";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};
struct ggml_webgpu_pad_pipeline_key_hash {
size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.circular);
return seed;
}
};
struct ggml_webgpu_pad_shader_lib_context {
ggml_webgpu_pad_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_pad_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "pad";
if (context.key.circular) {
defines.push_back("CIRCULAR");
variant += "_circular";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
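The key, hash and context structs added in this hunk follow one pattern: a small key struct with `operator==`, a hash functor built on `ggml_webgpu_hash_combine`, and a preprocess function that turns the key into a variant name. A minimal sketch of how such a key could drive a lookup table is shown below. The body of `ggml_webgpu_hash_combine` is not part of this diff, so a boost-style combiner is assumed, and every `demo_` name is hypothetical rather than taken from the backend.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Assumption: a boost-style combiner. The real ggml_webgpu_hash_combine is not shown in this diff.
template <typename T> inline void demo_hash_combine(size_t & seed, const T & value) {
    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Same shape as ggml_webgpu_pad_pipeline_key: one flag plus equality.
struct demo_pad_key {
    bool circular;
    bool operator==(const demo_pad_key & other) const { return circular == other.circular; }
};

struct demo_pad_key_hash {
    size_t operator()(const demo_pad_key & key) const {
        size_t seed = 0;
        demo_hash_combine(seed, key.circular);
        return seed;
    }
};

// Mirrors the variant naming done in ggml_webgpu_preprocess_pad_shader.
inline std::string demo_pad_variant(const demo_pad_key & key) {
    std::string variant = "pad";
    if (key.circular) {
        variant += "_circular";
    }
    return variant;
}

// Hypothetical cache: the variant string stands in for a compiled pipeline.
struct demo_pipeline_cache {
    std::unordered_map<demo_pad_key, std::string, demo_pad_key_hash> entries;
    size_t builds = 0;  // counts how often a variant had to be generated

    const std::string & get(const demo_pad_key & key) {
        auto it = entries.find(key);
        if (it == entries.end()) {
            ++builds;
            it = entries.emplace(key, demo_pad_variant(key)).first;
        }
        return it->second;
    }
};
```

Looking up the same key twice would reuse the stored entry, which is presumably why every field that changes the generated WGSL is part of the key.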
/** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context {
uint32_t max_wg_size;
size_t wg_mem_limit_bytes;
int32_t order;
};
struct ggml_webgpu_argsort_shader_decisions {
uint32_t wg_size = 0;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_argsort_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "argsort";
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
variant += std::string("_order") + std::to_string(context.order);
uint32_t wg_size = 1;
while (wg_size * 2 <= context.max_wg_size &&
wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
wg_size *= 2;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_argsort_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "argsort_merge";
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
variant += std::string("_order") + std::to_string(context.order);
uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
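The workgroup-size loop in `ggml_webgpu_preprocess_argsort_shader` can be checked in isolation. The sketch below copies the loop condition as written and assumes `GGML_WEBGPU_I32_SIZE_BYTES` is 4; the function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumption: GGML_WEBGPU_I32_SIZE_BYTES is 4 (its definition is not shown in this diff).
constexpr size_t DEMO_I32_SIZE_BYTES = 4;

// Doubles wg_size while the doubled size stays within max_wg_size and the
// current index array fits in half of the workgroup memory limit.
inline uint32_t demo_argsort_wg_size(uint32_t max_wg_size, size_t wg_mem_limit_bytes) {
    assert(max_wg_size >= 1);
    uint32_t wg_size = 1;
    while (wg_size * 2 <= max_wg_size &&
           wg_size * DEMO_I32_SIZE_BYTES <= wg_mem_limit_bytes / 2) {
        wg_size *= 2;
    }
    return wg_size;
}
```

The result is always a power of two, which the bitonic network in the argsort shader relies on. With `max_wg_size = 256` and a 16384-byte limit the loop returns 256; with a 64-byte limit it stops at 16.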
/** Set Rows **/
struct ggml_webgpu_set_rows_pipeline_key {
int dst_type;
int vec4;
int i64_idx;
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
}
};
struct ggml_webgpu_set_rows_pipeline_key_hash {
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.vec4);
ggml_webgpu_hash_combine(seed, key.i64_idx);
return seed;
}
};
struct ggml_webgpu_set_rows_shader_lib_context {
ggml_webgpu_set_rows_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_set_rows_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "set_rows";
switch (context.key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
variant += "_dstf32";
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
variant += "_dstf16";
break;
default:
GGML_ABORT("Unsupported dst type for set_rows shader");
}
if (context.key.vec4) {
defines.push_back("VEC4");
variant += "_vec";
}
if (context.key.i64_idx) {
defines.push_back("I64_IDX");
variant += "_i64idx";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
struct ggml_webgpu_unary_pipeline_key {
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
}
};
struct ggml_webgpu_unary_pipeline_key_hash {
size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
struct ggml_webgpu_unary_shader_lib_context {
ggml_webgpu_unary_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_unary_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
ggml_op_name((ggml_op) context.key.op);
// Operation-specific behavior
defines.push_back(variant);
switch (context.key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for unary shader");
}
if (context.key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
File diff suppressed because it is too large
@@ -0,0 +1,72 @@
@group(0) @binding(0)
#ifdef VEC4
var<storage, read_write> src: array<vec4<f32>>;
#define VEC_SIZE 4
#else
var<storage, read_write> src: array<f32>;
#define VEC_SIZE 1
#endif
@group(0) @binding(1)
var<storage, read_write> dst: array<i32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
ne0: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
const FLOAT_MIN: f32 = -1.0e9;
struct Pair {
value: f32,
index: i32
};
var<workgroup> shared_max: array<Pair, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let row_idx = params.offset_src + wid.x * params.ne0;
var local_pair = Pair(FLOAT_MIN, -1);
#ifdef VEC4
for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {
let vec_val = src[row_idx / VEC_SIZE + col];
for (var v = 0u; v < VEC_SIZE; v++) {
let val = vec_val[v];
if (val >= local_pair.value) {
local_pair = Pair(val, i32(col * VEC_SIZE + v));
}
}
}
#else
for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
if (src[row_idx + col] >= local_pair.value) {
local_pair = Pair(src[row_idx + col], i32(col));
}
}
#endif
shared_max[lid.x] = local_pair;
workgroupBarrier();
var offset: u32 = WG_SIZE >> 1;
while (offset > 0) {
if (lid.x < offset) {
let a = shared_max[lid.x];
let b = shared_max[lid.x + offset];
if (b.value > a.value) {
shared_max[lid.x] = b;
} else if (b.value == a.value && b.index > a.index) {
shared_max[lid.x] = b;
}
}
workgroupBarrier();
offset >>= 1;
}
if (lid.x == 0u) {
dst[params.offset_dst + wid.x] = shared_max[0].index;
}
}
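The argmax shader above works in two phases: each invocation scans a strided slice of the row, then the per-invocation candidates are combined by a tree reduction in workgroup memory. A CPU sketch of the non-`VEC4` path follows, with the invocations of one workgroup emulated by a sequential loop. Function and type names are hypothetical, and the tie-breaking is copied from the shader as shown.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct demo_pair {
    float   value;
    int32_t index;
};

// Emulates one workgroup of the argmax shader on a single row.
inline int32_t demo_argmax_row(const std::vector<float> & row, uint32_t wg_size) {
    assert(wg_size > 0 && (wg_size & (wg_size - 1)) == 0);  // halving reduction needs a power of two
    const float FLOAT_MIN = -1.0e9f;                         // same sentinel as the shader

    // Phase 1: each "invocation" lid scans columns lid, lid + wg_size, ...
    std::vector<demo_pair> shared_max(wg_size);
    for (uint32_t lid = 0; lid < wg_size; ++lid) {
        demo_pair local = { FLOAT_MIN, -1 };
        for (size_t col = lid; col < row.size(); col += wg_size) {
            if (row[col] >= local.value) {
                local = { row[col], (int32_t) col };
            }
        }
        shared_max[lid] = local;
    }

    // Phase 2: tree reduction. Each step reads slots >= offset and writes slots < offset,
    // so running the lanes one after another gives the same result as the barrier version.
    for (uint32_t offset = wg_size >> 1; offset > 0; offset >>= 1) {
        for (uint32_t lid = 0; lid < offset; ++lid) {
            const demo_pair a = shared_max[lid];
            const demo_pair b = shared_max[lid + offset];
            if (b.value > a.value || (b.value == a.value && b.index > a.index)) {
                shared_max[lid] = b;
            }
        }
    }
    return shared_max[0].index;
}
```

As written, equal maxima resolve to the larger index, and a row whose values are all below the `-1.0e9` sentinel would yield `-1`.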
@@ -0,0 +1,106 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<i32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// src/dst dimensions
src_ne0: u32,
ne1: u32,
ne2: u32,
ne0: u32,
top_k: u32,
npr: u32, // tiles per row
nrows: u32
};
@group(0) @binding(2)
var<uniform> params: Params;
var<workgroup> shmem_idx: array<u32, WG_SIZE>;
#if ORDER == 0
#define EXTREME_VALUE 1e30
#define SWAP_COMPARE_UP >
#define SWAP_COMPARE_DOWN <
#else
#define EXTREME_VALUE -1e30
#define SWAP_COMPARE_UP <
#define SWAP_COMPARE_DOWN >
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let linear = wid.x + wid.y * num_wg.x;
// guard against overprovisioned workgroups
if (linear >= params.npr * params.nrows) {
return;
}
let tile = linear % params.npr;
var row = linear / params.npr;
let i3 = row / (params.ne2 * params.ne1);
row = row % (params.ne2 * params.ne1);
let i2 = row / params.ne1;
let i1 = row % params.ne1;
let row_base = params.offset_src +
i1 * params.stride_src1 +
i2 * params.stride_src2 +
i3 * params.stride_src3;
let tile_base = tile * WG_SIZE;
let idx = tile_base + lid.x;
shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
workgroupBarrier();
var k = 2u;
while (k <= WG_SIZE) {
var j = k >> 1;
while (j > 0) {
let ixj = lid.x ^ j;
if (ixj > lid.x) {
let dir_up = (lid.x & k) == 0;
let a_idx = shmem_idx[lid.x];
let b_idx = shmem_idx[ixj];
let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0);
let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0);
let should_swap = select(
(a_val SWAP_COMPARE_DOWN b_val),
(a_val SWAP_COMPARE_UP b_val),
dir_up);
if (should_swap) {
shmem_idx[lid.x] = b_idx;
shmem_idx[ixj] = a_idx;
}
}
workgroupBarrier();
j >>= 1;
}
k <<= 1;
}
let out_idx = tile * params.top_k + lid.x;
if (out_idx < params.ne0 && lid.x < params.top_k) {
let row_dst = params.offset_dst +
i1 * params.stride_dst1 +
i2 * params.stride_dst2 +
i3 * params.stride_dst3;
dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
}
}
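The per-tile sort above is a bitonic network over indices held in workgroup memory, with out-of-range slots padded by a sentinel index that compares as an extreme value. The sketch below emulates one tile on the CPU. It assumes `ORDER == 0` means ascending, which is what the `>`/`<` macro selection implies, and all names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Emulates one workgroup (one tile) of the bitonic argsort shader.
// Returns the sorted indices; entries equal to src.size() are padding.
inline std::vector<uint32_t> demo_bitonic_argsort_tile(const std::vector<float> & src,
                                                       uint32_t                   wg_size,
                                                       int                        order) {
    assert(wg_size > 0 && (wg_size & (wg_size - 1)) == 0);
    const uint32_t n = (uint32_t) src.size();
    assert(n <= wg_size);  // this sketch handles a single tile only

    const float extreme = order == 0 ? 1e30f : -1e30f;  // EXTREME_VALUE in the shader
    auto value_of = [&](uint32_t idx) { return idx < n ? src[idx] : extreme; };

    std::vector<uint32_t> shmem_idx(wg_size);
    for (uint32_t lid = 0; lid < wg_size; ++lid) {
        shmem_idx[lid] = lid < n ? lid : n;  // n acts as the padding index
    }

    for (uint32_t k = 2; k <= wg_size; k <<= 1) {
        for (uint32_t j = k >> 1; j > 0; j >>= 1) {
            // Each (lid, lid ^ j) pair is disjoint, so a sequential loop matches the barrier version.
            for (uint32_t lid = 0; lid < wg_size; ++lid) {
                const uint32_t ixj = lid ^ j;
                if (ixj <= lid) {
                    continue;
                }
                const bool  dir_up = (lid & k) == 0;
                const float a_val  = value_of(shmem_idx[lid]);
                const float b_val  = value_of(shmem_idx[ixj]);
                const bool  gt     = a_val > b_val;
                const bool  lt     = a_val < b_val;
                // SWAP_COMPARE_UP / SWAP_COMPARE_DOWN flip with ORDER.
                const bool should_swap = dir_up ? (order == 0 ? gt : lt) : (order == 0 ? lt : gt);
                if (should_swap) {
                    std::swap(shmem_idx[lid], shmem_idx[ixj]);
                }
            }
        }
    }
    return shmem_idx;
}
```

For the values `{3, 1, 2, 0}` with a tile of four and ascending order, the network produces indices `{3, 1, 2, 0}`. With three inputs in a tile of four, the padding index ends up in the last slot.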
@@ -0,0 +1,134 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> idx_in: array<i32>;
@group(0) @binding(2)
var<storage, read_write> idx_out: array<i32>;
struct Params {
offset_src: u32, // in elements
offset_in: u32, // in elements
offset_out: u32, // in elements
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_idx1: u32,
stride_idx2: u32,
stride_idx3: u32,
stride_out1: u32,
stride_out2: u32,
stride_out3: u32,
ne0: u32,
ne1: u32,
ne2: u32,
top_k: u32,
len: u32,
nm: u32,
nrows: u32
};
@group(0) @binding(3)
var<uniform> params: Params;
fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
let a_val = src[row_base + u32(a_idx)];
let b_val = src[row_base + u32(b_idx)];
#if ORDER == 0
return a_val <= b_val;
#else
return a_val >= b_val;
#endif
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let linear = wid.x + wid.y * num_wg.x;
// guard against overprovisioned workgroups
if (linear >= params.nm * params.nrows) {
return;
}
let start = (linear % params.nm) * params.len * 2;
let len0 = min(params.len, params.ne0 - start);
let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
let len1 = min(params.len, rem1);
let total = len0 + len1;
let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
let k0 = lid.x * chunk;
let k1 = min(min(k0 + chunk, total), params.top_k);
// guard against overprovisioned threads
if (k0 >= params.top_k || k0 >= total) {
return;
}
var row = linear / params.nm;
let i3 = row / (params.ne2 * params.ne1);
row = row % (params.ne2 * params.ne1);
let i2 = row / params.ne1;
let i1 = row % params.ne1;
let row_src = params.offset_src +
i1 * params.stride_src1 +
i2 * params.stride_src2 +
i3 * params.stride_src3;
let row_in = params.offset_in +
i1 * params.stride_idx1 +
i2 * params.stride_idx2 +
i3 * params.stride_idx3;
let row_out = params.offset_out +
i1 * params.stride_out1 +
i2 * params.stride_out2 +
i3 * params.stride_out3;
var low: u32 = select(0, k0 - len1, k0 > len1);
var high: u32 = min(k0, len0);
while (low < high) {
let mid = (low + high) >> 1;
let idx0 = idx_in[row_in + start + mid];
let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
if (take_left(idx0, idx1, row_src)) {
low = mid + 1;
} else {
high = mid;
}
}
var i = low;
var j = k0 - i;
var k = k0;
while (k < k1) {
var take_l = false;
if (i >= len0) {
take_l = false;
} else if (j >= len1) {
take_l = true;
} else {
let idx0 = idx_in[row_in + start + i];
let idx1 = idx_in[row_in + start + params.len + j];
take_l = take_left(idx0, idx1, row_src);
}
let out_idx = select(
idx_in[row_in + start + params.len + j],
idx_in[row_in + start + i],
take_l);
idx_out[row_out + start + k] = out_idx;
i = select(i, i + 1, take_l);
j = select(j + 1, j, take_l);
k += 1;
}
}
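The merge shader above splits the output of each two-run merge across invocations: every invocation binary-searches for the point where its output chunk starts (a "merge path" style partition), then merges sequentially from there. The sketch below reproduces that logic for a single pair of runs on the CPU. It leaves out the `top_k` clipping and the row strides, and all names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Same comparison as take_left in the shader (ORDER == 0 assumed to mean ascending).
inline bool demo_take_left(const std::vector<float> & src, int32_t a_idx, int32_t b_idx, int order) {
    return order == 0 ? src[a_idx] <= src[b_idx] : src[a_idx] >= src[b_idx];
}

// Merges two sorted index runs; `workers` plays the role of WG_SIZE.
inline std::vector<int32_t> demo_merge_runs(const std::vector<float> &   src,
                                            const std::vector<int32_t> & left,
                                            const std::vector<int32_t> & right,
                                            size_t                       workers,
                                            int                          order) {
    assert(workers > 0);
    const size_t len0  = left.size();
    const size_t len1  = right.size();
    const size_t total = len0 + len1;
    const size_t chunk = (total + workers - 1) / workers;

    std::vector<int32_t> out(total);
    for (size_t w = 0; w < workers; ++w) {
        const size_t k0 = w * chunk;
        if (k0 >= total) {
            continue;  // overprovisioned worker
        }
        const size_t k1 = std::min(k0 + chunk, total);

        // Binary search: how many of the first k0 outputs come from the left run?
        size_t low  = k0 > len1 ? k0 - len1 : 0;
        size_t high = std::min(k0, len0);
        while (low < high) {
            const size_t mid = (low + high) >> 1;
            if (demo_take_left(src, left[mid], right[k0 - mid - 1], order)) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }

        // Sequential merge of this worker's chunk, starting at (i, j).
        size_t i = low;
        size_t j = k0 - low;
        for (size_t k = k0; k < k1; ++k) {
            bool take_l;
            if (i >= len0) {
                take_l = false;
            } else if (j >= len1) {
                take_l = true;
            } else {
                take_l = demo_take_left(src, left[i], right[j], order);
            }
            out[k] = take_l ? left[i++] : right[j++];
        }
    }
    return out;
}
```

Because each worker derives its own starting point, the result should not depend on how many workers share the merge; the same output comes back for one, two or four workers in the small cases tried.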
@@ -7,6 +7,12 @@
"DST_TYPE": "f32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "i32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
@@ -0,0 +1,66 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
ne0: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
var<workgroup> shared_sum: array<f32, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let row_idx = params.offset_src + wid.x * params.ne0;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
var local_sum: f32 = 0.0;
for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
local_sum += src[row_idx + col];
}
shared_sum[lid.x] = local_sum;
workgroupBarrier();
// upsweep
var offset = 1u;
while (offset < WG_SIZE) {
let idx = (lid.x + 1) * offset * 2 - 1;
if (idx < WG_SIZE) {
shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset];
}
workgroupBarrier();
offset <<= 1;
}
// set last to 0 for exclusive sum
if (lid.x == 0) {
shared_sum[WG_SIZE - 1] = 0.0;
}
workgroupBarrier();
// downsweep
offset = WG_SIZE >> 1;
while (offset > 0) {
let idx = (lid.x + 1) * offset * 2 - 1;
if (idx < WG_SIZE) {
let t = shared_sum[idx - offset];
shared_sum[idx - offset] = shared_sum[idx];
shared_sum[idx] = shared_sum[idx] + t;
}
workgroupBarrier();
offset = offset >> 1;
}
// shared_sum[lid] is exclusive prefix sum up to this thread.
var running_sum = shared_sum[lid.x];
for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
running_sum += src[row_idx + col];
dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum;
}
}
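The cumulative-sum shader above combines three steps: a per-invocation partial sum over a contiguous slice, an exclusive prefix sum over those partials using an upsweep/downsweep (Blelloch-style) scan in workgroup memory, and a final pass that adds each slice's running sum to its offset. A CPU sketch with hypothetical names, emulating one workgroup on one row:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

inline std::vector<float> demo_cumsum_row(const std::vector<float> & row, uint32_t wg_size) {
    assert(wg_size > 0 && (wg_size & (wg_size - 1)) == 0);  // the scan indexes assume a power of two
    const uint32_t ne0   = (uint32_t) row.size();
    const uint32_t elems = (ne0 + wg_size - 1) / wg_size;   // contiguous elements per invocation

    // Step 1: per-invocation partial sums.
    std::vector<float> shared_sum(wg_size, 0.0f);
    for (uint32_t lid = 0; lid < wg_size; ++lid) {
        for (uint32_t col = lid * elems; col < (lid + 1) * elems && col < ne0; ++col) {
            shared_sum[lid] += row[col];
        }
    }

    // Step 2a: upsweep (reduce).
    for (uint32_t offset = 1; offset < wg_size; offset <<= 1) {
        for (uint32_t lid = 0; lid < wg_size; ++lid) {
            const uint32_t idx = (lid + 1) * offset * 2 - 1;
            if (idx < wg_size) {
                shared_sum[idx] += shared_sum[idx - offset];
            }
        }
    }
    shared_sum[wg_size - 1] = 0.0f;  // clear the root for an exclusive scan

    // Step 2b: downsweep.
    for (uint32_t offset = wg_size >> 1; offset > 0; offset >>= 1) {
        for (uint32_t lid = 0; lid < wg_size; ++lid) {
            const uint32_t idx = (lid + 1) * offset * 2 - 1;
            if (idx < wg_size) {
                const float t            = shared_sum[idx - offset];
                shared_sum[idx - offset] = shared_sum[idx];
                shared_sum[idx] += t;
            }
        }
    }

    // Step 3: each invocation continues from its exclusive prefix.
    std::vector<float> out(ne0);
    for (uint32_t lid = 0; lid < wg_size; ++lid) {
        float running_sum = shared_sum[lid];
        for (uint32_t col = lid * elems; col < (lid + 1) * elems && col < ne0; ++col) {
            running_sum += row[col];
            out[col] = running_sum;
        }
    }
    return out;
}
```

For the row `1..8` with four invocations, the partial sums are `{3, 7, 11, 15}`, the exclusive scan turns them into `{0, 3, 10, 21}`, and the final pass yields `{1, 3, 6, 10, 15, 21, 28, 36}`.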
@@ -0,0 +1,86 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
src_ne3: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
dst_ne3: u32,
// Pad sizes (in elements)
lp0: u32,
rp0: u32,
lp1: u32,
rp1: u32,
lp2: u32,
rp2: u32,
lp3: u32,
rp3: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
fn wrap_around(idx: i32, n: u32) -> u32 {
return u32(idx + i32(n)) % n;
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0;
let i3 = i / dst_plane;
i = i % dst_plane;
let i2 = i / (params.dst_ne1 * params.dst_ne0);
i = i % (params.dst_ne1 * params.dst_ne0);
let i1 = i / params.dst_ne0;
let i0 = i % params.dst_ne0;
var value: f32 = 0.0;
#ifdef CIRCULAR
let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);
let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);
let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);
let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);
let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 +
ci2 * params.stride_src2 + ci3 * params.stride_src3;
value = src[params.offset_src + circular_src_idx];
#else
let is_src =
(i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) &&
(i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) &&
(i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) &&
(i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3);
if (is_src) {
let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 +
(i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3;
value = src[params.offset_src + src_idx];
}
#endif
dst[params.offset_dst + gid.x] = value;
}
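The two branches of the pad shader differ only in how a destination coordinate maps back to the source. A one-dimensional CPU sketch of both mappings is given below, with hypothetical names. One observation about the formula as written: `wrap_around` adds `n` once before the modulo, so it only handles offsets down to `-n`, which this sketch states as a precondition on the left pad.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Same arithmetic as wrap_around in the shader.
inline uint32_t demo_wrap_around(int32_t idx, uint32_t n) {
    return (uint32_t) (idx + (int32_t) n) % n;
}

// Pads a 1D row with lp elements on the left and rp on the right.
inline std::vector<float> demo_pad_1d(const std::vector<float> & src, uint32_t lp, uint32_t rp, bool circular) {
    const uint32_t src_ne0 = (uint32_t) src.size();
    assert(src_ne0 > 0);
    assert(!circular || lp <= src_ne0);  // precondition of the single "+ n" in wrap_around
    const uint32_t dst_ne0 = lp + src_ne0 + rp;

    std::vector<float> dst(dst_ne0, 0.0f);  // zero fill matches `var value: f32 = 0.0`
    for (uint32_t i0 = 0; i0 < dst_ne0; ++i0) {
        if (circular) {
            dst[i0] = src[demo_wrap_around((int32_t) i0 - (int32_t) lp, src_ne0)];
        } else if (i0 >= lp && i0 < dst_ne0 - rp) {
            dst[i0] = src[i0 - lp];
        }
    }
    return dst;
}
```

Padding `{10, 20, 30}` by one element on each side gives `{30, 10, 20, 30, 10}` in circular mode and `{0, 10, 20, 30, 0}` otherwise.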
@@ -1,41 +1,37 @@
-#define(VARIANTS)
-[
-{
-"SHADER_SUFFIX": "f16_vec",
-"REPLS": {
-"TYPE" : "vec4<f32>",
-"DST_TYPE": "vec4<f16>",
-"VEC_SIZE": 4
-}
-},
-{
-"SHADER_SUFFIX": "f16",
-"REPLS": {
-"TYPE" : "f32",
-"DST_TYPE": "f16",
-"VEC_SIZE": 1
-}
-}
-]
-#end(VARIANTS)
-#define(SHADER)
enable f16;
+#ifdef DST_F32
+#define DST_INNER_TYPE f32
+#else
+#define DST_INNER_TYPE f16
+#endif
+#ifdef VEC4
+#define SRC_TYPE vec4<f32>
+#define DST_TYPE vec4<DST_INNER_TYPE>
+#define VEC_SIZE 4
+#else
+#define SRC_TYPE f32
+#define DST_TYPE DST_INNER_TYPE
+#define VEC_SIZE 1
+#endif
@group(0) @binding(0)
-var<storage, read_write> src: array<{{TYPE}}>;
+var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;
@group(0) @binding(2)
-var<storage, read_write> dst: array<{{DST_TYPE}}>;
+var<storage, read_write> dst: array<DST_TYPE>;
+#ifdef I64_IDX
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
+#define PARAMS_BINDING 4
+#else
+#define PARAMS_BINDING 3
+#endif
struct Params {
offset_src: u32, // in elements
@@ -66,18 +62,17 @@ struct Params {
idx2: u32,
};
-@group(0) @binding(4)
+@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
+@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
+if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
return;
}
// getting the row from gid
-let elems_per_row = params.ne0 / {{VEC_SIZE}};
+let elems_per_row = params.ne0 / VEC_SIZE;
var i = gid.x / elems_per_row;
let i_src3 = i / (params.ne2 * params.n_rows);
@@ -90,9 +85,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_idx1 = i_src2 % params.idx1;
let i_idx0 = i_src1;
+#ifdef I64_IDX
let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
-let idx_high_val = idx[idx_high];
+let idx_val = idx[idx_high];
let idx_low_val = idx[idx_high + 1];
if (idx_low_val != 0) {
@@ -100,13 +96,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
atomicStore(&error, 1);
return;
}
+#else
+let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+let idx_val = idx[idx_i];
+#endif
-let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
let col_idx = (gid.x % elems_per_row);
-dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
+dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
}
#end(SHADER)
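In the `I64_IDX` path above, the index buffer is bound as `array<u32>`, so each 64-bit row index occupies two consecutive words. The shader uses the first word as the row and raises the error flag when the second word is non-zero. The sketch below models that check on the CPU. It assumes a little-endian layout, where the first word holds the low 32 bits, and all names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

struct demo_u32_pair {
    uint32_t first;   // word at idx[2 * i]     (low bits on a little-endian layout)
    uint32_t second;  // word at idx[2 * i + 1] (high bits on a little-endian layout)
};

// Splits a 64-bit index the way it would appear in an array<u32> view.
inline demo_u32_pair demo_split_i64(int64_t v) {
    const uint64_t bits = (uint64_t) v;
    return { (uint32_t) (bits & 0xffffffffu), (uint32_t) (bits >> 32) };
}

// Returns false (the shader's atomicStore(&error, 1) case) when the index
// does not fit in 32 bits; otherwise writes the usable row index.
inline bool demo_narrow_index(demo_u32_pair words, uint32_t & row) {
    if (words.second != 0) {
        return false;
    }
    row = words.first;
    return true;
}
```

Under that assumption a negative index is rejected as well, because its upper word is all ones.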
@@ -0,0 +1,55 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
ne0: u32,
ne1: u32,
ne2: u32
};
@group(0) @binding(2)
var<uniform> params: Params;
var<workgroup> shared_sum: array<f32, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
var local_sum: f32 = 0.0;
for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
local_sum += src[i_src_row + col];
}
shared_sum[lid.x] = local_sum;
workgroupBarrier();
// reduce within workgroup
var offset: u32 = WG_SIZE >> 1;
while (offset > 0) {
if (lid.x < offset) {
shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset];
}
workgroupBarrier();
offset >>= 1;
}
if (lid.x == 0) {
dst[params.offset_dst + wid.x] = shared_sum[0];
}
}
@@ -0,0 +1,179 @@
#ifdef TYPE_F16
enable f16;
#define TYPE f16
#else
#define TYPE f32
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<TYPE>;
#ifndef INPLACE
@group(0) @binding(1)
var<storage, read_write> dst: array<TYPE>;
#define PARAMS_BINDING 2
#else
#define PARAMS_BINDING 1
#endif
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
// Logical shapes
ne0: u32,
ne1: u32,
ne2: u32,
#ifdef CLAMP
clamp_min: f32,
clamp_max: f32,
#endif
#ifdef FILL
fill_val: f32,
#endif
#ifdef XIELU
alpha_n: f32,
alpha_p: f32,
beta: f32,
eps: f32,
#endif
};
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
#ifdef ABS
let res = abs(src[params.offset_src + src_idx]);
#endif
#ifdef SGN
let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef NEG
let res = -src[params.offset_src + src_idx];
#endif
#ifdef STEP
let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
#endif
#ifdef TANH
let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
#endif
#ifdef RELU
let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef ELU
let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef HARDSIGMOID
let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
#endif
#ifdef SIGMOID
let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
#endif
#ifdef SILU
let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
#endif
#ifdef EXP
let res = exp(src[params.offset_src + src_idx]);
#endif
#ifdef LOG
let res = TYPE(log(f32(src[params.offset_src + src_idx])));
#endif
#ifdef CLAMP
let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
#endif
#ifdef FILL
let res = TYPE(params.fill_val);
#endif
#ifdef HARDSWISH
let res = src[params.offset_src + src_idx] *
min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
#endif
#ifdef GELU
let res = 0.5 * src[params.offset_src + src_idx] *
(1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
(src[params.offset_src + src_idx] +
0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
-9.010913, 9.010913)));
#endif
#ifdef GELU_QUICK
let res = src[params.offset_src + src_idx] * 0.5 *
(1.0 + tanh(clamp(0.79788456 *
(src[params.offset_src + src_idx] +
0.044715 * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
-9.010913, 9.010913)));
#endif
#ifdef GELU_ERF
let res = 0.5 * src[params.offset_src + src_idx] *
(1.0 + tanh(clamp(0.79788456 *
(src[params.offset_src + src_idx] +
0.044715 * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
-9.010913, 9.010913)));
#endif
#ifdef XIELU
let res =
select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
src[params.offset_src + src_idx]) *
TYPE(params.alpha_n) +
TYPE(params.beta) * src[params.offset_src + src_idx],
TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] +
TYPE(params.beta) * src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef SOFTPLUS
let src_f32 = f32(src[params.offset_src + src_idx]);
let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
#endif
#ifdef EXPM1
let res = exp(src[params.offset_src + src_idx]) - 1.0;
#endif
#ifdef FLOOR
let res = floor(src[params.offset_src + src_idx]);
#endif
#ifdef CEIL
let res = ceil(src[params.offset_src + src_idx]);
#endif
#ifdef ROUND
let src_f32 = f32(src[params.offset_src + src_idx]);
let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
let res = TYPE(result);
#endif
#ifdef TRUNC
let res = trunc(src[params.offset_src + src_idx]);
#endif
#ifdef INPLACE
src[params.offset_src + src_idx] = res;
#else
dst[params.offset_dst + gid.x] = res;
#endif
}
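A few of the unary branches above carry numerical choices that are easier to see outside the preprocessor blocks: `ROUND` rounds halves away from zero, `SOFTPLUS` returns the input unchanged above 20, and `TANH` clamps its argument to about ±9.01 (the removed template file links this to the WGSL `tanh()` domain issue, gpuweb/gpuweb#4458). The sketch below restates those three in plain C++ with hypothetical names. It is an f32-only illustration, not the shader's exact arithmetic for f16.

```cpp
#include <cassert>
#include <cmath>

// ROUND branch: halves round away from zero.
inline float demo_round(float x) {
    return x >= 0.0f ? std::floor(x + 0.5f) : std::ceil(x - 0.5f);
}

// SOFTPLUS branch: log(1 + exp(x)), with a pass-through above the threshold of 20.
inline float demo_softplus(float x) {
    return x > 20.0f ? x : std::log(1.0f + std::exp(x));
}

// TANH branch: the argument is clamped before calling tanh.
inline float demo_tanh_clamped(float x) {
    const float limit = 9.010913f;
    return std::tanh(std::fmin(std::fmax(x, -limit), limit));
}
```

For inputs above 20, `exp(-x)` is smaller than f32 rounding at that magnitude, so the pass-through in `demo_softplus` changes nothing in single precision and avoids evaluating `exp` on large arguments.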
@@ -1,483 +0,0 @@
#define(REPL_TEMPLATES)
{
"XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
"ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
"SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
"NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
"STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
"TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
"ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
"HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
"SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
"SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
"EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
"HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
"GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);"
}
#end(REPL_TEMPLATES)
#define(VARIANTS)
[
{
"SHADER_NAME": "abs_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "abs_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "abs_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "abs_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sgn_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sgn_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sgn_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sgn_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "neg_f32",
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "neg_f16",
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "neg_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "neg_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "step_f32",
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "step_f16",
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "step_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "step_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "tanh_f32",
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "tanh_f16",
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "tanh_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "tanh_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "elu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "elu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "elu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "elu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "relu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "relu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "relu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "relu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sigmoid_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sigmoid_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "silu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "silu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "silu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "silu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "exp_f32",
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "exp_f16",
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "exp_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "exp_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardswish_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardswish_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardswish_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardswish_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "xielu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "xielu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "ceil_f32",
"REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "ceil_f16",
"REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "ceil_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "ceil_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
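Reviewer note: each variant above supplies `REPLS` values that are substituted into `{{...}}` placeholders of the templates. The generator that performs this substitution is not part of this diff, so the following is only a minimal sketch of the idea; `expand` is a hypothetical helper, not a function from the repository.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical helper: replace every "{{KEY}}" in `text` with the value
// registered for KEY in `repls`. Unknown placeholders are left untouched.
std::string expand(std::string text, const std::map<std::string, std::string> & repls) {
    for (const auto & kv : repls) {
        assert(!kv.first.empty());
        const std::string tag = "{{" + kv.first + "}}";
        std::size_t pos = 0;
        while ((pos = text.find(tag, pos)) != std::string::npos) {
            text.replace(pos, tag.size(), kv.second);
            pos += kv.second.size(); // skip past the inserted value
        }
    }
    return text;
}
```

With `MUTATE` set to `src` and `TYPE` set to `f16`, a template line such as `{{MUTATE}}[dst_i] = {{TYPE}}(1.0);` would become `src[dst_i] = f16(1.0);`, which is how the in-place variants appear to write back into the source buffer.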
#define(DECLS)
#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(INPLACE)
#decl(NOT_INPLACE)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
fn update(dst_i: u32, src_i: u32) {
{{FUNC}}
}
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
DECLS
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
{{EXT_PARAMS}}
};
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
i = i % (params.src_ne1 * params.src_ne0);
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;
var j = gid.x;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne1 * params.dst_ne0);
let j1 = j / params.dst_ne0;
let j0 = j % params.dst_ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;
update(params.offset_dst + dst_idx, params.offset_src + src_idx);
}
#end(SHADER)
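Reviewer note: the `main` function of the removed shader turns the flat invocation index into per-dimension coordinates using the logical shape, then recombines them with the (possibly permuted) strides. The same arithmetic can be sketched on the CPU as follows; `Shape3`, `Strides4` and `strided_offset` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the shape and stride fields of Params.
struct Shape3   { uint32_t ne0, ne1, ne2; };
struct Strides4 { uint32_t s0, s1, s2, s3; };

// Split a flat element index into (i0, i1, i2, i3) using the logical shape,
// then combine the coordinates with per-dimension strides (in elements).
uint32_t strided_offset(uint32_t flat, Shape3 ne, Strides4 st) {
    assert(ne.ne0 > 0 && ne.ne1 > 0 && ne.ne2 > 0);
    uint32_t i = flat;
    const uint32_t i3 = i / (ne.ne2 * ne.ne1 * ne.ne0);
    i = i % (ne.ne2 * ne.ne1 * ne.ne0);
    const uint32_t i2 = i / (ne.ne1 * ne.ne0);
    i = i % (ne.ne1 * ne.ne0);
    const uint32_t i1 = i / ne.ne0;
    const uint32_t i0 = i % ne.ne0;
    return i0 * st.s0 + i1 * st.s1 + i2 * st.s2 + i3 * st.s3;
}
```

For a contiguous layout the result equals the flat index; for a permuted layout (for example, the first two strides swapped) it differs, which is why the shader computes source and destination offsets separately.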
+6 -1
@@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
bool ok = ggml_zdnn_compute_forward(ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: unsupported op %s (%s)\n",
@@ -368,7 +372,8 @@ static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_ty
}
static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
return true;
/* while it resides in host memory, additional transformation is needed */
return false;
GGML_UNUSED(buft);
}
+4
@@ -211,6 +211,10 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_zendnn_compute_forward_mul_mat(ctx, node);
+52 -18
@@ -3441,7 +3441,8 @@ struct ggml_tensor * ggml_cast(
result->op = GGML_OP_CPY;
result->src[0] = a;
result->src[1] = result;
result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some
// backends for consistency with ggml_cpy_impl() above
return result;
}
@@ -6725,20 +6726,35 @@ static void ggml_compute_backward(
GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
}
static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
// check if already visited
size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {
if (node->op != GGML_OP_NONE && compute) {
node->flags |= GGML_TENSOR_FLAG_COMPUTE;
}
const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
// This is the first time we see this node in the current graph.
cgraph->visited_hash_set.keys[node_hash_pos] = node;
ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
cgraph->use_counts[node_hash_pos] = 0;
} else {
if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
// already visited
if (compute) {
// update the compute flag regardless
for (int i = 0; i < GGML_MAX_SRC; ++i) {
struct ggml_tensor * src = node->src[i];
if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) {
ggml_visit_parents_graph(cgraph, src, true);
}
}
}
return node_hash_pos;
}
// This is the first time we see this node in the current graph.
cgraph->visited_hash_set.keys[node_hash_pos] = node;
ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
cgraph->use_counts[node_hash_pos] = 0;
for (int i = 0; i < GGML_MAX_SRC; ++i) {
const int k =
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
@@ -6747,7 +6763,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
struct ggml_tensor * src = node->src[k];
if (src) {
size_t src_hash_pos = ggml_visit_parents(cgraph, src);
const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);
// Update the use count for this operand.
cgraph->use_counts[src_hash_pos]++;
@@ -6778,17 +6794,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
return node_hash_pos;
}
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {
if (!expand) {
// TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
ggml_graph_clear(cgraph);
}
const int n0 = cgraph->n_nodes;
const int n_old = cgraph->n_nodes;
ggml_visit_parents(cgraph, tensor);
ggml_visit_parents_graph(cgraph, tensor, compute);
const int n_new = cgraph->n_nodes - n0;
const int n_new = cgraph->n_nodes - n_old;
GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
if (n_new > 0) {
@@ -6797,8 +6813,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
}
}
struct ggml_tensor * ggml_build_forward_select(
struct ggml_cgraph * cgraph,
struct ggml_tensor ** tensors,
int n_tensors,
int idx) {
GGML_ASSERT(idx >= 0 && idx < n_tensors);
for (int i = 0; i < n_tensors; i++) {
ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false);
}
return tensors[idx];
}
void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
ggml_build_forward_impl(cgraph, tensor, true);
ggml_build_forward_impl(cgraph, tensor, true, true);
}
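Reviewer note: the hunks above add every candidate tensor to the graph but mark only the selected branch (and its ancestors) with the compute flag, so that backends can skip the rest. A toy model of that propagation is sketched below. `Node`, `visit` and `build_select` are hypothetical stand-ins, not the ggml types, and the sketch omits use counts and node ordering.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for a graph tensor (hypothetical, not ggml_tensor).
struct Node {
    bool is_op   = true;   // false plays the role of GGML_OP_NONE (a leaf)
    bool compute = false;  // plays the role of GGML_TENSOR_FLAG_COMPUTE
    bool visited = false;  // plays the role of the visited hash set
    std::vector<Node *> src;
};

// Visit a node and its sources; set the compute flag only when requested.
void visit(Node * n, bool compute) {
    if (n->is_op && compute) {
        n->compute = true;
    }
    if (n->visited) {
        // Already in the graph: still push the flag down to unflagged sources.
        if (compute) {
            for (Node * s : n->src) {
                if (s && !s->compute) {
                    visit(s, true);
                }
            }
        }
        return;
    }
    n->visited = true;
    for (Node * s : n->src) {
        if (s) {
            visit(s, compute);
        }
    }
}

// Add all candidates to the graph, but flag only candidate `idx` for compute.
Node * build_select(std::vector<Node *> & outs, int idx) {
    assert(idx >= 0 && idx < (int) outs.size());
    for (int i = 0; i < (int) outs.size(); ++i) {
        visit(outs[i], i == idx);
    }
    return outs[idx];
}
```

In this model, a node shared by a selected and an unselected branch ends up flagged even if the unselected branch reached it first, which matches the "update the compute flag regardless" path in the diff.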
void ggml_build_backward_expand(
@@ -7229,6 +7259,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return false;
}
if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
continue;
}
@@ -7310,7 +7344,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node,
label);
}
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {
char color[16];
FILE * fp = ggml_fopen(filename, "w");
@@ -7331,7 +7365,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
snprintf(color, sizeof(color), "yellow");
} else if (grad) {
if (ggml_graph_find(gf, node)) {
if (ggml_graph_find(cgraph, node)) {
snprintf(color, sizeof(color), "green");
} else {
snprintf(color, sizeof(color), "lightblue");
+1 -1
@@ -734,7 +734,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
FILE * file = ggml_fopen(fname, "rb");
if (!file) {
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno));
return nullptr;
}
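Reviewer note: the change above appends the operating-system reason to the log message via `strerror(errno)`. A minimal, self-contained sketch of the same pattern is shown below. `open_error` is a hypothetical helper, and it assumes a platform where `fopen` sets `errno` on failure (POSIX does; the C standard does not require it).

```cpp
#include <cassert>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <string>

// Hypothetical helper: returns an empty string if the file can be opened,
// otherwise a message that includes the reason reported by the OS.
std::string open_error(const char * fname) {
    assert(fname != nullptr);
    errno = 0;
    FILE * f = std::fopen(fname, "rb");
    if (f) {
        std::fclose(f);
        return "";
    }
    const int err = errno; // capture before any other call can overwrite it
    return std::string("failed to open '") + fname + "' (" + std::strerror(err) + ")";
}
```

Capturing `errno` immediately after the failing call is the usual precaution, since later library calls may change it.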
+1 -1
@@ -3,7 +3,7 @@ pytest~=8.3.3
huggingface_hub>=0.34.0,<1.0
matplotlib~=3.10.0
numpy~=1.26.4
openai~=1.55.3
openai~=2.14.0
pandas~=2.2.3
prometheus-client~=0.20.0
requests~=2.32.3
+1
@@ -24,6 +24,7 @@ add_library(llama
llama-kv-cache-iswa.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-hybrid-iswa.cpp
llama-memory-recurrent.cpp
llama-mmap.cpp
llama-model-loader.cpp
+112
@@ -7,6 +7,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
#include <cassert>
@@ -510,6 +511,76 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
return res;
}
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
const auto * attn_ctx = mctx->get_attn();
// base tensors may not be allocated if there are no non-SWA attention layers
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
}
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
int32_t * data = (int32_t *) inp_rs->s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->get_recr()->s_copy(i);
}
}
}
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
const auto * attn_ctx = mctx->get_attn();
// base tensors may not be allocated if there are no non-SWA attention layers
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
}
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
}
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
res &= inp_rs->head == mctx->get_recr()->get_head();
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
return res;
}
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
// set the inputs only for the active samplers in the current ubatch
std::unordered_set<llama_seq_id> active_samplers;
@@ -2056,6 +2127,47 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
// build iswa attention input
const auto * attn_ctx = mctx_cur->get_attn();
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{
const auto n_kv = attn_ctx->get_base()->get_n_kv();
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp_attn->self_kq_mask);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
{
const auto n_kv = attn_ctx->get_swa()->get_n_kv();
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp_attn->self_kq_mask_swa);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
}
void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const {
+31
@@ -24,6 +24,7 @@ class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
class llama_memory_hybrid_iswa_context;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -397,6 +398,34 @@ public:
const llama_memory_hybrid_context * mctx;
};
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid_iswa(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_iswa_context * mctx) :
inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)),
cparams(cparams),
mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_cparams cparams;
const llama_memory_hybrid_iswa_context * mctx;
};
class llm_graph_input_sampling : public llm_graph_input_i {
public:
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
@@ -881,6 +910,8 @@ struct llm_graph_context {
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
//
// pooling
//
-36
@@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const {
return res;
}
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
if (p0 < pos_chunk_start) {
return true;
}
} break;
case LLAMA_SWA_TYPE_SYMMETRIC:
{
const int32_t half_n_swa = (int32_t) n_swa / 2;
const int32_t pos_diff = p1 - p0;
// Mask if outside the symmetric window
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
return true;
}
} break;
}
return false;
}
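Reviewer note: the function removed from this file decides whether a key at position `p0` is hidden from a query at position `p1` under each sliding-window type. The standalone sketch below mirrors that logic so the three window rules can be checked in isolation. The enum and function names are hypothetical stand-ins for `llama_swa_type` and `is_masked_swa`, and the chunk start is computed in signed arithmetic for simplicity.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the LLAMA_SWA_TYPE_* values.
enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED, SWA_SYMMETRIC };

// Returns true when the key at p0 must be masked for the query at p1.
bool is_masked(uint32_t n_swa, swa_type type, int32_t p0, int32_t p1) {
    assert(p0 >= 0 && p1 >= 0);
    const int32_t n = (int32_t) n_swa;
    switch (type) {
        case SWA_NONE:
            break;
        case SWA_STANDARD:
            // mask keys that are n_swa or more positions behind the query
            if (p1 - p0 >= n) {
                return true;
            }
            break;
        case SWA_CHUNKED: {
            // mask keys that fall before the start of the query's chunk
            const int32_t chunk_start = (p1 / n) * n;
            if (p0 < chunk_start) {
                return true;
            }
        } break;
        case SWA_SYMMETRIC: {
            // mask keys outside a window of half_n on either side
            const int32_t half_n = n / 2;
            const int32_t diff   = p1 - p0;
            if (diff < -half_n || diff > half_n) {
                return true;
            }
        } break;
    }
    return false;
}
```

With `n_swa = 4`, for instance, the standard rule masks `(p0, p1) = (0, 4)` but not `(1, 4)`, while the chunked rule masks `(3, 5)` because the query's chunk starts at position 4.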
bool llama_hparams::use_mrope() const {
return rope_sections[0] > 0 && rope_sections[1] > 0;
}

Some files were not shown because too many files have changed in this diff.