Compare commits

...

41 Commits

Author SHA1 Message Date
Diego Devesa a2e0088d92 Revert "ggml : Leverage the existing GGML_F32_VEC helpers to vectorize ggml_v…" (#16723)
This reverts commit 19a5a3edfd.
2025-10-22 20:20:55 +02:00
Pascal 9b9201f65a webui: introduce OpenAI-compatible model selector in JSON payload (#16562)
* webui: introduce OpenAI-compatible model selector in JSON payload

* webui: restore OpenAI-Compatible model source of truth and unify metadata capture

This change re-establishes a single, reliable source of truth for the active model:
fully aligned with the OpenAI-Compat API behavior

It introduces a unified metadata flow that captures the model field from both
streaming and non-streaming responses, wiring a new onModel callback through ChatService
The model name is now resolved directly from the API payload rather than relying on
server /props or UI assumptions

ChatStore records and persists the resolved model for each assistant message during
streaming, ensuring consistency across the UI and database
Type definitions for API and settings were also extended to include model metadata
and the onModel callback, completing the alignment with OpenAI-Compat semantics

* webui: address review feedback from allozaur

* webui: move model selector into ChatForm (idea by @allozaur)

* webui: make model selector more subtle and integrated into ChatForm

* webui: replaced the Flowbite selector with a native Svelte dropdown

* webui: add developer setting to toggle the chat model selector

* webui: address review feedback from allozaur

Normalized streamed model names during chat updates
by trimming input and removing directory components before saving
or persisting them, so the conversation UI shows only the filename

Forced model names within the chat form selector dropdown to render as
a single-line, truncated entry with a tooltip revealing the full name

* webui: toggle displayed model source for legacy vs OpenAI-Compat modes

When the selector is disabled, it falls back to the active server model name from /props

When the model selector is enabled, the displayed model comes from the message metadata
(the one explicitly selected and sent in the request)

* Update tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/constants/localstorage-keys.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormModelSelector.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/services/chat.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/services/chat.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* webui: refactor model selector and persistence helpers

- Replace inline portal and event listeners with proper Svelte bindings
- Introduce 'persisted' store helper for localStorage sync without runes
- Extract 'normalizeModelName' utils + Vitest coverage
- Simplify ChatFormModelSelector structure and cleanup logic

Replaced the persisted store helper's use of '$state/$effect' runes with
a plain TS implementation to prevent orphaned effect runtime errors
outside component context

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* webui: document normalizeModelName usage with inline examples

* Update tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormModelSelector.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/stores/models.svelte.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/stores/models.svelte.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* webui: extract ModelOption type into dedicated models.d.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* webui: refine ChatMessageAssistant displayedModel source logic

* webui: stabilize dropdown, simplify model extraction, and init assistant model field

* chore: update webui static build

* Update tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* chore: npm format, update webui static build

* webui: align sidebar trigger position, remove z-index glitch

* chore: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-10-22 16:58:23 +02:00
sirus20x6 19a5a3edfd ggml : Leverage the existing GGML_F32_VEC helpers to vectorize ggml_vec_set_f32 for faster fills (#16522)
* Leverage the existing GGML_F32_VEC helpers to broadcast the fill value across SIMD registers and store in vector-sized chunks, while retaining the scalar tail for leftover elements and non-SIMD builds.

* Vectorize additional f32 helper loops

* Normalize f32 helper tails for ggml vec ops

---------

Co-authored-by: Aaron <shelhamer.aaron@gmail.com>
2025-10-22 12:14:14 +02:00
Acly d8eaa26e4d tests : fix test-thread-safety when compiling with multiple backends (#16699)
* run one test per backend/device (even if it's the same device)
2025-10-22 12:01:22 +02:00
Aman Gupta 9285325ce0 CUDA: fix bug in topk-moe softmax (#16711) 2025-10-22 12:33:08 +08:00
Aman Gupta 03792ad936 CUDA: topk-moe: add optional parameter for gpt-oss (#16649) 2025-10-21 22:40:38 +08:00
Johannes Gäßler 51d1a8c997 CUDA: better error for FA kernel with 0 occupancy (#16643) 2025-10-21 15:27:53 +02:00
Aman Gupta 4926419c4d ggml: add ggml_can_fuse_subgraph (#16662)
* ggml: add ggml_can_fuse_subgraph

* ggml-cuda: use ggml_can_fuse_subgraph for topk-moe

* format

* 1. remove inputs from signature as they are transient nodes
2. add check for views: view_src should be part of the subgraph

* - combine check into one loop
- check all view_src parents
- other minor review comments

* remove redudant if test

* - rename and other minor review comments

* add assert about count < 32
2025-10-21 16:43:14 +08:00
lhez 6ea37f5739 opencl: fix warnings and clean up profiling (#16688)
* opencl: remove unused headers, fix warnings

* opencl: clean up profiling, only keep kernel time
2025-10-20 22:26:17 -07:00
Jeff Bolz fb349848f3 vulkan: Handle FA with all -inf mask values (#16447) 2025-10-20 22:16:08 -05:00
YehuditE 6de8ed7519 sycl : add PAD_REFLECT_D1 operator support (#16145)
* sycl: add PAD_REFLECT_D1 operator support

* docs(ops): regenerate docs/ops.md

* remove trailing whitespaces

* style: fix editorconfig issues — trim trailing spaces and normalize EOLs

* fix: move PAD_REFLECT_1D case outside of fall-through block
2025-10-21 00:21:12 +02:00
Sigbjørn Skjæret 84bf3c6778 model : add BailingMoeV2 support (#16063)
* add BailingMoeV2 support

* update llm types

* undo

* undo

* update llm types

* add model collection link

* update

* almost working

* correct group selection and rename n_group_exp

* avoid large top_k and use argmax instead for now

if we had something like argmax2 that would be equivalent, but this works fine until then

* poke

* skip group selection when there are no tokens

* fix 1T conversion

* hopefully fixed expert group selection

third time's the charm?

* make expert group selection generally available

The new LLaDA2Moe model uses this method too, make it generally available regardless of architecture.

* allow n_expert_groups to be 1 (Kimi K2)

* address review suggestions
2025-10-20 21:38:20 +02:00
Aleksander Grygier c9c1972e2c Handle legacy 'context' attachments (#16687) 2025-10-20 19:49:02 +02:00
Diego Devesa b617cfd289 ggml-alloc : fix leak when reusing a tensor with a larger size (#16679) 2025-10-20 14:53:50 +02:00
Aleksander Grygier 79068501fa Prevent premature submission on IME input (#16673)
* fix: Prevent premature submission on IME input

* chore: update webui static build

* refactor: Put IME completion checker in a helper function and add checking for `KeyboardEvent.eventKey === 229`

* chore: update webui static build

* chore: update webui static build

* chore: update webui static build
2025-10-20 14:21:12 +02:00
Aleksander Grygier 0e4a0cf2fa Import/Export UX improvements (#16619)
* webui : added download action (#13552)

* webui : import and export (for all conversations)

* webui : fixed download-format, import of one conversation

* webui : add ExportedConversations type for chat import/export

* feat: Update naming & order

* chore: Linting

* feat: Import/Export UX improvements

* chore: update webui build output

* feat: Update UI placement of Import/Export tab in Chat Settings Dialog

* refactor: Cleanup

chore: update webui build output

* feat: Enable shift-click multiple conversation items selection

* chore: update webui static build

* chore: update webui static build

---------

Co-authored-by: Sascha Rogmann <github@rogmann.org>
2025-10-20 13:29:14 +02:00
Aleksander Grygier 13f2cfad41 Enable per-conversation loading states to allow having parallel conversations (#16327)
* feat: Per-conversation loading states and tracking streaming stats

* chore: update webui build output

* refactor: Chat state management

Consolidates loading state management by using a global `isLoading` store synchronized with individual conversation states.

This change ensures proper reactivity and avoids potential race conditions when updating the UI based on the loading status of different conversations. It also improves the accuracy of statistics displayed.

Additionally, slots service methods are updated to use conversation IDs for per-conversation state management, avoiding global state pollution.

* feat: Adds loading indicator to conversation items

* chore: update webui build output

* fix: Fix aborting chat streaming

Improves the chat stream abortion process by ensuring that partial responses are saved before the abort signal is sent.

This avoids a race condition where the onError callback could clear the streaming state before the partial response is saved. Additionally, the stream reading loop and callbacks are now checked for abort signals to prevent further processing after abortion.

* refactor: Remove redundant comments

* chore: build webui static output

* refactor: Cleanup

* chore: update webui build output

* chore: update webui build output

* fix: Conversation loading indicator for regenerating messages

* chore: update webui static build

* feat: Improve configuration

* feat: Install `http-server` as dev dependency to not need to rely on `npx` in CI
2025-10-20 12:41:13 +02:00
takuya kodama 06332e2867 llama-batch: fix build fails with -Werror=missing-braces (#16614)
## Why it failed

When compiling with strict compiler flags (-Wmissing-braces -Werror=missing-braces),
the build fails with the following error:

```
cmake \
  -S . \
  -B ../llama.cpp.build \
  --preset=x64-linux-gcc-debug \
  -DCMAKE_INSTALL_PREFIX=/tmp/local \
  -DCMAKE_CXX_FLAGS="-Wmissing-braces -Werror=missing-braces" && \
cmake --build ../llama.cpp.build/
...
In file included from /home/otegami/work/cpp/llama.cpp/src/llama-graph.h:4,
                 from /home/otegami/work/cpp/llama.cpp/src/llama-model.h:5,
                 from /home/otegami/work/cpp/llama.cpp/src/llama.cpp:8:
/home/otegami/work/cpp/llama.cpp/src/llama-batch.h:126:48: error: missing braces around initializer for 'std::__array_traits<int, 1>::_Type' {aka 'int [1]'} [-Werror=missing-braces]
  126 |     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
      |                                                ^
cc1plus: some warnings being treated as errors
```

The issue is that std::array initialization requires double braces.

## How to fix

This PR changes `{ 0 }` to `{{ 0 }}` for std::array initialization.

This is part of a series of commits to fix missing braces warnings across the codebase.
- src/llama-batch.h <- This PR is here.
- src/llama-context.cpp
- tests/test-backend-ops.cpp
- tests/test-gguf.cpp
- tools/mtmd/clip.cpp

Benefits:
- std::array is a struct containing a C-style array, requiring nested braces
- Enables stricter compiler warnings to catch potential issues
2025-10-20 11:27:09 +03:00
Ron Evans 72d53e6c4d readme: update bindings (#16651)
Signed-off-by: deadprogram <ron@hybridgroup.com>
2025-10-20 11:20:04 +03:00
safranowith 2330de7b84 SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators (#16613)
* SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators

Clean up unrelated changes from previous commit

* Chore: remove empty lines and fix indentation

* Clean up: remove leftover blank lines and fix spacing

* chore: fix trailing whitespace and ensure final newline

* Cleanup: remove redundant declarations already defined in header

* Sync docs/ops.md with updated backend operation support

* docs: update ops.md after rebase

* docs: update ops.md - Vulkan supports SSM_CONV and SSM_SCAN
2025-10-20 11:08:32 +03:00
takuya kodama 7062dd8460 llama-context: only warn on pooling_type when user specified (#16674)
The unexpeced pooling_type warning was incorrectly shown when users did not
specify the --pooling-type parameter. In this case, the parameter
defaults to `LLAMA_POOLING_TYPE_UNSPECIFIED (-1)`, and the code
automatically applies the model's default pooling type.

Example of spurious warning:
```
$ llama-embedding -hf ggml-org/bge-m3-Q8_0-GGUF -p "hello"
...
llama_init_from_model: model default pooling_type is [2], but [-1] was specified
...
```

This fix ensures the warning only appears when users explicitly specify
a pooling type that differs from the model's default (e.g., using
--pooling-type mean on a model that expects CLS pooling).
2025-10-20 10:44:21 +03:00
Giuseppe Scrivano 0398752dd4 model : add Granite Hybrid types (#16635)
add Granite 4 models mapping their embedding dimensions to the # of
parameters.

Information taken from https://huggingface.co/ibm-granite/granite-4.0-h-tiny

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2025-10-19 23:54:31 +02:00
Aaron Teo 4f73d0a951 ci : fix binaries release failure for s390x (binaries may not work yet) (#16664)
* devops: initial patch

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* devops: forgot the z15 suffix

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* devops: attempt at impl GGML_CPU_ALL_VARIANTS for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* devops: rm baseline version

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

---------

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-10-19 23:06:39 +02:00
Sigbjørn Skjæret cec5edbcae ci : avoid manual updates of docs/ops.md (#16663) 2025-10-19 14:03:25 +02:00
Aaron Teo fcb235b466 ci: include s390x release binaries (#16648)
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-10-19 18:37:47 +08:00
Aman Gupta 55754bebd5 CODEOWNERS: update for ggml-cuda/mmf (#16660) 2025-10-19 10:37:12 +03:00
Johannes Gäßler ee09828cb0 HIP: fix GPU_TARGETS (#16642) 2025-10-18 14:47:32 +02:00
Jeff Bolz e56abd2098 vulkan: Implement topk_moe fused shader, ported from CUDA (#16641)
This is similar to the CUDA shader from #16130, but doesn't use shared memory
and handles different subgroup sizes.
2025-10-18 12:22:57 +02:00
Aman Gupta 38355c6c8e CUDA: use registers instead of smem in topk-moe (#16647)
Uses the technique used in the vulkan PR #16641. Neat trick!
2025-10-18 11:52:53 +02:00
Shawn Gu 81387858f1 opencl: transposed gemm/gemv moe kernel with mxfp4,f32 (#16602)
* opencl: transposed gemm/gemv moe kernel with mxfp4,f32

* add restore kernel for moe transpose

* fix trailing whitespaces

* resolve compilation warnings
2025-10-17 17:55:32 -07:00
Johannes Gäßler 66b0dbcb2d llama-model: fix insonsistent ctxs <-> bufs order (#16581) 2025-10-17 17:41:09 +02:00
Radoslav Gerganov 41386cf365 rpc : report actual free memory (#16616)
* rpc : report actual free memory

Start reporting the free memory on every device instead of using
fixed values. Now llama-cli users can get a nice memory breakdown
when using RPC devices.

* drop --mem in rpc-server
2025-10-17 18:02:52 +03:00
Giuseppe Scrivano 3d4e86bbeb vulkan: Add State Space Model (SSM) Operations Support (#16463)
* vulkan: implement SSM scan operation

Add State Space Model scan operation to the Vulkan backend.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>

* vulkan: implement SSM conv operation

Add State Space Model conv operation to the Vulkan backend.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>

---------

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2025-10-17 14:23:47 +02:00
muggle-stack 342c728d03 ggml : fix SpaceMit IME array out-of-bounds in task assignment (#16629)
Fix incorrect task-to-batch index calculation in the quantization phase.

The bug caused out-of-bounds access to qnbitgemm_args array when
compute_idx exceeded per_gemm_block_count_m, leading to invalid
pointer dereferences and SIGBUS errors.

Correctly map tasks to batches by dividing compute_idx by
per_gemm_block_count_m instead of block_size_m.

Example:
  batch_feature=1, gemm_m=30, block_size_m=4
  per_gemm_block_count_m = 8, task_count = 8

  Old: gemm_idx = 4/4 = 1 (out of bounds  New: gemm_idx = 4/8 = 0 (correct)

Tested on SpaceMit K1 RISC-V64 with qwen2.5:0.5b model.

Co-authored-by: muggle <mingjun.rong@spacemit.com>
2025-10-17 13:01:23 +03:00
Pascal ababae7e1e webui: reorganize settings layout (#16607)
* webui: reorganize settings layout

* chore: update webui build output

* fix: remove unused variable

* chore: update webui build output
2025-10-17 10:35:03 +02:00
Jeff Bolz b19491599d vulkan: fix debug build (add_rms_len/data not found) (#16624) 2025-10-17 09:31:04 +02:00
Ilia Ilmer 9ad4f1931e metal : add CONV_TRANSPOSE_2D (#16542)
* initial: headers and metal-device.cpp updates

* adding conv_transpose_2d

* fix type

* fix type: int32->int64

* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-metal/ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add checks for src[0] and src[1]; add type checks

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add more tests, add optimization to threading

* add dynamic memory allocation in metal

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-17 09:33:58 +03:00
Olivier Chafik 79967ec596 grammar : use int64_t to avoid int overflows in int schema to grammar conversion logic (#16626) 2025-10-17 08:59:31 +03:00
GittyBurstein ceff6bb253 SYCL SET operator optimized for F32 tensors (#16350)
* SYCL/SET: implement operator + wire-up; docs/ops updates; element_wise & ggml-sycl changes

* sycl(SET): re-apply post-rebase; revert manual docs/ops.md; style cleanups

* move SET op to standalone file, GPU-only implementation

* Update SYCL SET operator for F32

* ci: fix editorconfig issues (LF endings, trailing spaces, final newline)

* fixed ggml-sycl.cpp

---------

Co-authored-by: Gitty Burstein <gitty@example.com>
2025-10-17 10:36:40 +08:00
Xuan-Son Nguyen 1bb4f43380 mtmd : support home-cooked Mistral Small Omni (#14928) 2025-10-16 19:00:31 +02:00
Pascal 683fa6ba4e fix: added a normalization step for MathJax-style \[\] and \(\) delimiters (#16599)
* fix: added a normalization step for MathJax-style \[\] and \(\) delimiters

So inline and block equations are converted before KaTeX rendering,
enabling proper display of model-generated LaTeX in the WebUI

* chore: update webui build output
2025-10-16 16:28:41 +02:00
117 changed files with 5723 additions and 722 deletions
+2
View File
@@ -134,6 +134,8 @@ jobs:
include:
- build: 'x64'
os: ubuntu-22.04
- build: 's390x-z15' # z15 because our CI runners are on z15
os: ubuntu-22.04-s390x
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
# - build: 'arm64'
# os: ubuntu-22.04-arm
+2
View File
@@ -3,10 +3,12 @@ name: Update Operations Documentation
on:
push:
paths:
- 'docs/ops.md'
- 'docs/ops/**'
- 'scripts/create_ops_docs.py'
pull_request:
paths:
- 'docs/ops.md'
- 'docs/ops/**'
- 'scripts/create_ops_docs.py'
+1 -1
View File
@@ -55,7 +55,7 @@
/ggml/src/ggml-cuda/common.cuh @slaren
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
+2
View File
@@ -138,6 +138,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
#### Multimodal
@@ -187,6 +188,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
</details>
+1 -1
View File
@@ -75,7 +75,7 @@ if [ ! -z ${GG_BUILD_ROCM} ]; then
exit 1
fi
CMAKE_EXTRA="${CMAKE_EXTRA} -DAMDGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
CMAKE_EXTRA="${CMAKE_EXTRA} -DGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
fi
if [ ! -z ${GG_BUILD_SYCL} ]; then
+12 -12
View File
@@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int>::min();
auto has_max = max_value != std::numeric_limits<int>::max();
static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int64_t>::min();
auto has_max = max_value != std::numeric_limits<int64_t>::max();
auto digit_range = [&](char from, char to) {
out << "[";
@@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
if (has_min) {
if (min_value < 0) {
out << "\"-\" (";
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
_build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
out << ") | [0] | [1-9] ";
more_digits(0, decimals_left - 1);
} else if (min_value == 0) {
@@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
}
digit_range(c, c);
out << " (";
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
_build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
out << ")";
if (c < '9') {
out << " | ";
@@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
} else {
out << "\"-\" (";
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
_build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
out << ")";
}
return;
@@ -925,17 +925,17 @@ public:
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
int min_value = std::numeric_limits<int>::min();
int max_value = std::numeric_limits<int>::max();
int64_t min_value = std::numeric_limits<int64_t>::min();
int64_t max_value = std::numeric_limits<int64_t>::max();
if (schema.contains("minimum")) {
min_value = schema["minimum"].get<int>();
min_value = schema["minimum"].get<int64_t>();
} else if (schema.contains("exclusiveMinimum")) {
min_value = schema["exclusiveMinimum"].get<int>() + 1;
min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
}
if (schema.contains("maximum")) {
max_value = schema["maximum"].get<int>();
max_value = schema["maximum"].get<int64_t>();
} else if (schema.contains("exclusiveMaximum")) {
max_value = schema["exclusiveMaximum"].get<int>() - 1;
max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
}
std::stringstream out;
out << "(";
+99 -2
View File
@@ -892,8 +892,8 @@ class TextModel(ModelBase):
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
res = "llada-moe"
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
res = "bailingmoe2"
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
res = "granite-docling"
@@ -8055,6 +8055,103 @@ class BailingMoeModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("BailingMoeV2ForCausalLM")
class BailingMoeV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
else:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_group_count(hparams["n_group"])
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
if hparams["score_function"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["score_function"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "mlp.experts" in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
tensors: list[tuple[str, Tensor]] = []
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE
+1 -1
View File
@@ -139,7 +139,7 @@ models = [
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
]
+7 -7
View File
@@ -22,7 +22,7 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
@@ -42,7 +42,7 @@ Legend:
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@@ -72,7 +72,7 @@ Legend:
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
@@ -84,7 +84,7 @@ Legend:
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -100,8 +100,8 @@ Legend:
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
@@ -111,6 +111,6 @@ Legend:
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+18 -2
View File
@@ -31,6 +31,14 @@
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
"SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@@ -95,6 +103,14 @@
"SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
"SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@@ -9363,8 +9379,8 @@
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","1","yes","SYCL"
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","SYCL"
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","SYCL"
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
"SYCL0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","SYCL"
"SYCL0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","SYCL"
"SYCL0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","SYCL"
Can't render this file because it is too large.
+21 -21
View File
@@ -3263,27 +3263,27 @@
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=0","support","1","yes","Vulkan"
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=1","support","1","yes","Vulkan"
"Vulkan0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","Vulkan"
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","Vulkan"
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
Can't render this file because it is too large.
+1 -2
View File
@@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
+12
View File
@@ -307,6 +307,10 @@ function(ggml_add_cpu_backend_variant tag_name)
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
endif()
ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -371,6 +375,14 @@ if (GGML_CPU_ALL_VARIANTS)
else()
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
else()
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif()
+22
View File
@@ -598,6 +598,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
}
// free the extra space at the end if the new tensor is smaller
static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
GGML_ASSERT(parent_size >= node_size);
if (parent_size > node_size) {
struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
struct buffer_address p_addr = p_hn->addr;
p_addr.offset += node_size;
size_t extra_size = parent_size - node_size;
AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
}
}
static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
GGML_ASSERT(buffer_id >= 0);
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
@@ -643,6 +663,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
hn->addr = p_hn->addr;
p_hn->allocated = false; // avoid freeing the parent
view_src_hn->allocated = false;
ggml_gallocr_free_extra_space(galloc, node, view_src);
return;
}
} else {
@@ -650,6 +671,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
hn->buffer_id = p_hn->buffer_id;
hn->addr = p_hn->addr;
p_hn->allocated = false; // avoid freeing the parent
ggml_gallocr_free_extra_space(galloc, node, parent);
return;
}
}
+36 -20
View File
@@ -466,29 +466,45 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
list(APPEND GGML_CPU_SOURCES
ggml-cpu/arch/s390/quants.c)
# TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562")
message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15)
elseif (${S390X_M} MATCHES "3931")
message(STATUS "z16 target")
list(APPEND ARCH_FLAGS -march=z16)
elseif (${S390X_M} MATCHES "9175|9176")
# NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
# binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
message(STATUS "z17 target")
list(APPEND ARCH_FLAGS -march=arch15)
else()
message(STATUS "Unknown target")
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
list(APPEND ARCH_FLAGS -march=native -mtune=native)
# for native compilation
if (GGML_NATIVE)
# check machine level to determine target
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
# TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562")
message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15)
elseif (${S390X_M} MATCHES "3931")
message(STATUS "z16 target")
list(APPEND ARCH_FLAGS -march=z16)
elseif (${S390X_M} MATCHES "9175|9176")
# NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
# binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
message(STATUS "z17 target")
list(APPEND ARCH_FLAGS -march=arch15)
else()
message(STATUS "Unknown target")
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
list(APPEND ARCH_FLAGS -march=native -mtune=native)
endif()
# for cross-compilation
elseif(GGML_CPU_ALL_VARIANTS)
# range through IBM z15 to z17
# NOTE: update when a new hardware level is released
foreach (ZHW RANGE 15 17)
if(DEFINED GGML_INTERNAL_Z${ZHW})
message(STATUS "z${ZHW} cross-compile target")
list(APPEND ARCH_FLAGS -march=z${ZHW})
endif()
endforeach()
endif()
if (GGML_VXE)
if (GGML_VXE OR GGML_INTERNAL_VXE)
message(STATUS "VX/VXE/VXE2 enabled")
list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE)
+3 -2
View File
@@ -485,8 +485,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
int32_t start = ith * task_per_thread;
int32_t end = std::min((ith + 1) * task_per_thread, task_count);
for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
int32_t gemm_idx = compute_idx / block_size_m;
int32_t m_idx = compute_idx % block_size_m * block_size_m;
int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
int32_t m_idx = block_idx_in_gemm * block_size_m;
const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
+1
View File
@@ -895,6 +895,7 @@ void launch_fattn(
const dim3 block_dim(warp_size, nwarps, 1);
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
GGML_ASSERT(max_blocks_per_sm > 0);
int parallel_blocks = max_blocks_per_sm;
dim3 blocks_num;
+35 -23
View File
@@ -2818,18 +2818,15 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
#endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
std::initializer_list<enum ggml_op> topk_moe_ops =
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
}
if (ops.size() == topk_moe_ops_with_norm.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+8];
@@ -2838,16 +2835,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_ops.size(); i++) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
}
if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@@ -2855,6 +2844,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
return true;
}
}
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
@@ -2948,7 +2947,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i+8];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false);
i += 8;
continue;
}
@@ -2956,11 +2956,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i+4];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i,
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 5];
ggml_tensor * ids = cgraph->nodes[i + 1];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
/*delayed_softmax*/ true);
i += 5;
continue;
}
if (node->op == GGML_OP_ADD) {
int n_fuse = 0;
ggml_op ops[8];
+121 -70
View File
@@ -4,16 +4,61 @@
#include <initializer_list>
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
float max_val = -INFINITY;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
max_val = max(max_val, vals[i]);
}
}
max_val = warp_reduce_max(max_val);
float sum = 0.f;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
const float val = expf(vals[i] - max_val);
vals[i] = val;
sum += val;
} else {
vals[i] = 0.f;
}
}
sum = warp_reduce_sum(sum);
const float inv_sum = 1.0f / sum;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
if (active) {
vals[i] *= inv_sum;
}
}
}
/*
This kernel does the following:
1. softmax over the logits per token [n_experts, n_tokens]
1. optionally softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_experts_used) logits
3. write weights + ids to global memory
4. optionally normalize the weights
4. optionally normalize the weights or apply softmax over the selected logits
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
template <int n_experts, bool with_norm>
template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
@@ -30,51 +75,30 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
float logits_r[experts_per_thread];
float wt[experts_per_thread];
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
const int expert = i + threadIdx.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
}
float max_val = logits_r[0];
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
if constexpr (!delayed_softmax) {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
}
max_val = warp_reduce_max(max_val);
float wt[experts_per_thread];
float tmp = 0.f;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = expf(val - max_val);
tmp += wt[i];
}
tmp = warp_reduce_sum(tmp);
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
}
//at this point, each thread holds a portion of softmax,
//we do the argmax reduce over n_expert_used, each time marking
//at this point, each thread holds either a portion of the softmax distribution
//or the raw logits. We do the argmax reduce over n_expert_used, each time marking
//the expert weight as -inf to exclude from the next iteration
float wt_sum = 0.f;
extern __shared__ float data_topk_shared[];
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
float output_weights[experts_per_thread];
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
output_weights[i] = 0.f;
}
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
@@ -99,11 +123,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
output_weights[k / WARP_SIZE] = max_val;
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
wt_shared_ptr[k] = max_val;
ids[k] = max_expert;
ids[k] = max_expert;
if constexpr (with_norm) {
wt_sum += max_val;
}
@@ -114,17 +141,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
wt_sum = warp_reduce_sum(wt_sum);
const float inv_sum = 1.0f / wt_sum;
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
for (int i = 0; i < experts_per_thread; i++) {
output_weights[i] *= inv_sum;
}
}
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
weights[i] = wt_shared_ptr[i];
if constexpr (delayed_softmax) {
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
}
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
}
}
}
template <bool with_norm>
template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
@@ -132,53 +167,53 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const int n_rows,
const int n_expert,
const int n_expert_used) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
cudaStream_t stream = ctx.stream();
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 2:
topk_moe_cuda<2, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 4:
topk_moe_cuda<4, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 8:
topk_moe_cuda<8, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 16:
topk_moe_cuda<16, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 32:
topk_moe_cuda<32, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 64:
topk_moe_cuda<64, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 128:
topk_moe_cuda<128, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 256:
topk_moe_cuda<256, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 512:
topk_moe_cuda<512, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
default:
GGML_ASSERT(false && "fatal error");
@@ -190,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm) {
const bool with_norm,
const bool delayed_softmax) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -198,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];
const float * logits_d = (const float *) logits->src[0]->data;
const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;
@@ -209,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
if (with_norm) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else {
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
}
}
}
@@ -242,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return true;
}
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
@@ -250,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
GGML_ASSERT(!norm || !delayed_softmax);
if (delayed_softmax) {
return delayed_softmax_ops;
}
if (norm) {
return norm_ops;
}
return no_norm_ops;
}
+4 -3
View File
@@ -6,9 +6,10 @@
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * top_k,
const bool with_norm);
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax = false);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
+4 -2
View File
@@ -28,8 +28,10 @@ if (CXX_IS_HIPCC)
" Prefer setting the HIP compiler directly. See README for details.")
endif()
else()
# Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
# Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
endif()
cmake_minimum_required(VERSION 3.21)
+48 -2
View File
@@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
const struct ggml_tensor * node = cgraph->nodes[node_idx];
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
return 0;
}
return cgraph->use_counts[hash_pos];
}
// return true if the node's results are only used by N other nodes
// and can be fused into their calculations.
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
const struct ggml_tensor * node = cgraph->nodes[node_idx];
// check the use count against how many we're replacing
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
return false;
}
@@ -638,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
}
GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
const int * node_idxs,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs);
// Returns true if the subgraph formed by {node_idxs} can be fused
// checks whethers all nodes which are not part of outputs can be elided
// by checking if their num_uses are confined to the subgraph
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int node_idx,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs) {
GGML_ASSERT(count < 32);
if (node_idx + count > cgraph->n_nodes) {
return false;
}
int idxs[32];
for (int i = 0; i < count; ++i) {
idxs[i] = node_idx + i;
}
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
}
#ifdef __cplusplus
}
#endif
@@ -651,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
}
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int start_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<int> outputs = {}) {
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+25
View File
@@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
char base[256];
char name[256];
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_UPSCALE);
+1
View File
@@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+5
View File
@@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SCALE:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
+13
View File
@@ -514,6 +514,19 @@ typedef struct {
uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;
typedef struct {
int32_t IC;
int32_t IH;
int32_t IW;
int32_t KH;
int32_t KW;
int32_t OC;
int32_t s0;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;
typedef struct {
uint64_t ofs0;
uint64_t ofs1;
+60
View File
@@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
} break;
case GGML_OP_UPSCALE:
{
n_fuse = ggml_metal_op_upscale(ctx, idx);
@@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
const int32_t IC = op->src[1]->ne[2];
const int32_t IH = op->src[1]->ne[1];
const int32_t IW = op->src[1]->ne[0];
const int32_t KH = op->src[0]->ne[1];
const int32_t KW = op->src[0]->ne[0];
const int32_t OW = op->ne[0];
const int32_t OH = op->ne[1];
const int32_t OC = op->ne[2];
ggml_metal_kargs_conv_transpose_2d args = {
/*.IC =*/ IC,
/*.IH =*/ IH,
/*.IW =*/ IW,
/*.KH =*/ KH,
/*.KW =*/ KW,
/*.OC =*/ OC,
/*.s0 =*/ s0,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
// Metal requires buffer size to be multiple of 16 bytes
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
return 1;
}
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
+1
View File
@@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
+91
View File
@@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
kernel void kernel_upscale_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
+2
View File
@@ -91,6 +91,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_id_q8_0_f32_flat
mul_mv_id_mxfp4_f32
mul_mv_id_mxfp4_f32_flat
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
+210 -20
View File
@@ -15,13 +15,12 @@
#include <CL/cl.h>
#include <inttypes.h>
#include <string.h>
#include <cstddef>
#include <cstdint>
#include <atomic>
#include <fstream>
#include <limits>
#include <vector>
#include <string>
#include <cmath>
@@ -402,6 +401,7 @@ struct ggml_backend_opencl_context {
cl_program program_conv_2d_f32;
cl_program program_conv_2d_f16_f32;
cl_program program_tsembd;
cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
cl_program program_mul_mv_id_mxfp4_f32;
@@ -452,7 +452,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mat_f16_f32_tiled;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
cl_kernel kernel_convert_block_q4_0_noshuffle;
@@ -475,6 +475,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_conv_2d_f32;
cl_kernel kernel_conv_2d_f16_f32;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
cl_kernel kernel_mul_mv_id_mxfp4_f32;
@@ -531,25 +532,17 @@ struct ggml_backend_opencl_context {
}
// Dump a csv
float total_kernel_time = 0;
fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n");
for (const ProfilingInfo & info : profiling_info) {
total_kernel_time += info.cmd_duration_ns/1.e6f;
fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
info.op_name.c_str(), info.kernel_name.c_str(),
info.cmd_queued_duration_ns/1.e6f,
info.cmd_submit_duration_ns/1.e6f,
info.cmd_duration_ns/1.e6f,
info.cmd_complete_duration_ns/1.e6f,
info.cmd_total_duration_ns/1.e6f,
info.global_size[0], info.global_size[1], info.global_size[2],
info.local_size[0], info.local_size[1], info.local_size[2],
info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
}
fclose(fperf);
GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
// Dump a simple chrome trace
FILE* ftrace = fopen("cl_trace.json", "w");
if (!ftrace) {
@@ -559,14 +552,14 @@ struct ggml_backend_opencl_context {
fprintf(ftrace, "[\n");
for (const ProfilingInfo & info : profiling_info) {
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_queued/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_submit/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_start/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_end/1000);
}
fclose(ftrace);
@@ -777,6 +770,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -1991,6 +1986,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
GGML_LOG_CONT(".");
}
std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
" -cl-mad-enable "
" -cl-fast-relaxed-math";
// gemv_moe_mxfp4_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemv_moe_mxfp4_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
#endif
backend_ctx->program_gemv_moe_mxfp4_f32 =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
GGML_LOG_CONT(".");
}
// gemm_moe_mxfp4_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_moe_mxfp4_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
#endif
backend_ctx->program_gemm_moe_mxfp4_f32 =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
GGML_LOG_CONT(".");
}
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
GGML_LOG_CONT("\n");
}
@@ -3299,6 +3330,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
tensor->ne[2] == 1 && tensor->ne[3] == 1;
}
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
GGML_UNUSED(backend_ctx);
int ne01 = tensor->ne[1];
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
}
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
@@ -3601,14 +3638,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
int ne02 = tensor->ne[2];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
size_t local_work_size[3] = {64, 2, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
return;
}
#endif
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[3] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
@@ -3624,7 +3686,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
{ extra->q }
};
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
tensor->extra = extra;
return;
@@ -3751,6 +3812,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
int ne00 = tensor->ne[0];
int ne01 = tensor->ne[1];
int ne02 = tensor->ne[2];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
size_t local_work_size[3] = {64, 2, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
#endif
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -7553,6 +7641,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
const int ne21 = src2->ne[1];
const cl_ulong nb21 = src2->nb[1];
const cl_ulong nb20 = src2->nb[0];
UNUSED(nb20);
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
@@ -7692,6 +7783,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
break;
}
case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (use_adreno_moe_kernels(backend_ctx, src0)) {
cl_int status;
size_t local_size[3] = {64, 2, 1};
size_t global_size[3] = {64, 2, 1};
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
int tile_size = 320;
if (ne12 == 1) { // for gemv
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
// create a sub_buffer for src2
cl_buffer_region region;
region.origin = offset2;
region.size = ne20 * ne21 * sizeof(int);
buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(ne01);
global_size[1] = 4;
global_size[2] = static_cast<size_t>(ne20);
local_size[1] = 4;
} else { // for gemm
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
// preprocess router table
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
void * host_src2 = malloc(ne21 * nb21);
CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
int total_experts = nb21 / nb20;
int out_idx = 0;
for (int i_expert = 0; i_expert < ne02; i_expert++) {
for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
for (int j = 0; j < ne21; j++) {
for (int i = 0; i < ne20; i++) {
int expert = ((int *)host_src2)[j * total_experts + i];
if (i_expert == expert) {
((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
out_idx += 4;
}
}
}
}
}
buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
CL_CHECK(status);
// set thread grid
global_size[0] = static_cast<size_t>(tile_size);
global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
}
// create a sub_buffer for src1
cl_buffer_region region;
region.origin = offset1;
region.size = ne10 * ne11 * ne12 * sizeof(float);
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
CL_CHECK(status);
// create image for src1
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
CL_CHECK(status);
// Set kernel args
int arg_idx = 0;
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
if (ne12 == 1) {
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
} else {
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
}
// launch kernel
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
// deallocate sub buffers and images
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
CL_CHECK(clReleaseMemObject(buf_src1_image));
CL_CHECK(clReleaseMemObject(buf_src2));
return;
} // else fallback to generic kernel
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
+42
View File
@@ -147,6 +147,27 @@ kernel void kernel_convert_block_mxfp4(
}
}
kernel void kernel_convert_block_mxfp4_trans(
global struct block_mxfp4 * src0,
__global uint4 * dst_q,
__global uchar * dst_e,
uint ne00,
uint ne01
) {
int i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
global struct block_mxfp4 * b = src0 + src_blk_offset;
dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];
dst_e[dst_blk_offset] = b->e;
}
kernel void kernel_restore_block_mxfp4(
global uchar * src_q,
global half * src_e,
@@ -162,6 +183,27 @@ kernel void kernel_restore_block_mxfp4(
}
}
kernel void kernel_restore_block_mxfp4_trans(
__global uint4 * src_q,
__global uchar * src_e,
global struct block_mxfp4 * dst,
uint ne00,
uint ne01
) {
int i00 = get_global_id(1);
uint i01 = get_global_id(0);
uint i02 = get_global_id(2);
uint ne00_blk = ne00 / QK_MXFP4;
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
global struct block_mxfp4 * b = dst + dst_blk_offset;
((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
b->e = src_e[src_blk_offset];
}
//------------------------------------------------------------------------------
// block_q8_0
//------------------------------------------------------------------------------
@@ -0,0 +1,162 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_MXFP4 32
#define N_SIMDGROUP 2
#define SIMDGROUP_WIDTH 64
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemm_moe_mxfp4_f32(
__global uint4 * src0_q,
__global uchar * src0_e,
__read_only image1d_buffer_t src1,
__global ushort4 * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int tile_size
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
ushort4 router = src2[i20];
ushort expert_id = router.x;
ushort i11 = router.y;
ushort i1 = router.z;
ushort tile_id = router.w;
if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size
return;
}
uint expert_offset = expert_id * ne00 * ne01 / 32;
uint tile_offset = expert_offset + tile_id * tile_size + i01;
__private float sum = 0.0f; // each thread calculate partial sum of one output
// loop along ne00 in block granularity, skip 4 blocks every iter
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ = src0_q[tile_offset + ib00 * ne01];
// convert 8 fp4 to fp16
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
uint offset = i11 * ne00 / 4 + ib00 * 8;
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
uchar regE = src0_e[tile_offset + ib00 * ne01];
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
// if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
// if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
// if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
// if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 outputs per thread in subgroup 0
if (sgid == 0) {
dst = dst + (offsetd >> 2);
dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
}
}
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define QK_MXFP4 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
sign_b.hi = fp4x8.s0 & 0x8000;
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
ushort2 fp16_packed_a_1, fp16_packed_b_1;
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
sign_b.hi = fp4x8.s1 & 0x8000;
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
static inline float e8m0_to_fp32(uchar x) {
int bits;
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
return as_float(bits);
}
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32(
__global uint4 * src0_q,
__global uchar * src0_e,
__read_only image1d_buffer_t src1,
__global uint * src2,
__global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne11
) {
uint i01 = get_global_id(0);
uint i20 = get_global_id(2);
uint sgid = get_local_id(1);
uint slid = get_sub_group_local_id();
uint i11 = i20 % ne11;
uint expert_id = src2[i20];
uint expert_offset = expert_id * ne00 * ne01 / 32;
__private float sum = 0.0f; // each thread calculate partial sum of one output
// loop along ne00 in block granularity, skip 4 blocks every iter
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
// load one block of q
uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];
uint offset = i11 * ne00 / 4 + ib00 * 8;
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
float4 shared_y4;
shared_y4 = read_imagef(src1, (offset + 0));
float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 4));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
shared_y4 = read_imagef(src1, (offset + 1));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 5));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
shared_y4 = read_imagef(src1, (offset + 2));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 6));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
shared_y4 = read_imagef(src1, (offset + 3));
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
shared_y4 = read_imagef(src1, (offset + 7));
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
}
// reduction in local memory, assumes #subgroups=4
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 1 outputs per thread in subgroup 0
if (sgid == 0) {
dst = dst + (offsetd >> 2);
dst[i01 + i20 * ne01] = sum;
}
}
+24 -15
View File
@@ -939,6 +939,7 @@ public:
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return true;
}
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
size_t free, total;
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
ggml_backend_dev_memory(dev, &free, &total);
response.free_mem = free;
response.total_mem = total;
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
return true;
}
rpc_server::~rpc_server() {
for (auto buffer : buffers) {
ggml_backend_buffer_free(buffer);
@@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() {
}
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
sockfd_t sockfd) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
@@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
auto dev_id = request.device;
if (dev_id >= backends.size()) {
rpc_msg_get_device_memory_rsp response;
if (!server.get_device_memory(request, response)) {
return;
}
rpc_msg_get_device_memory_rsp response;
response.free_mem = free_mem[dev_id];
response.total_mem = total_mem[dev_id];
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
}
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
if (n_devices == 0 || devices == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
@@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
total / 1024 / 1024, free / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
@@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
printf("Accepted client connection\n");
fflush(stdout);
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
rpc_serve_client(backends, cache_dir, client_socket->fd);
printf("Client connection closed\n");
fflush(stdout);
}
+2
View File
@@ -37,5 +37,7 @@
#include "softmax.hpp"
#include "tsembd.hpp"
#include "wkv.hpp"
#include "pad_reflect_1d.hpp"
#endif // GGML_SYCL_BACKEND_HPP
+120
View File
@@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
}
template<typename T>
static __dpct_inline__ T op_floor(T x) {
return sycl::floor(x);
}
template<typename T>
static __dpct_inline__ T op_ceil(T x) {
return sycl::ceil(x);
}
template<typename T>
static __dpct_inline__ T op_round(T x) {
return sycl::round(x);
}
template<typename T>
static __dpct_inline__ T op_trunc(T x) {
return sycl::trunc(x);
}
template<typename T>
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
}
}
template<typename T>
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_floor(x[i]);
}
}
template<typename T>
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_ceil(x[i]);
}
}
template<typename T>
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_round(x[i]);
}
}
template<typename T>
static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_trunc(x[i]);
}
}
template<typename T>
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -897,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
}, min_val, max_val);
}
static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -1122,3 +1222,23 @@ void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
}
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_floor(ctx, dst);
}
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_ceil(ctx, dst);
}
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_round(ctx, dst);
}
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_trunc(ctx, dst);
}
+4
View File
@@ -80,6 +80,10 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+31
View File
@@ -42,6 +42,7 @@
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/quantize.hpp"
@@ -3619,6 +3620,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GET_ROWS:
ggml_sycl_get_rows(ctx, dst);
break;
case GGML_OP_SET:
ggml_sycl_op_set(ctx, dst);
break;
case GGML_OP_SET_ROWS:
ggml_sycl_op_set_rows(ctx, dst);
break;
@@ -3694,6 +3698,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_UNARY_OP_ELU:
ggml_sycl_elu(ctx, dst);
break;
case GGML_UNARY_OP_FLOOR:
ggml_sycl_floor(ctx, dst);
break;
case GGML_UNARY_OP_CEIL:
ggml_sycl_ceil(ctx, dst);
break;
case GGML_UNARY_OP_ROUND:
ggml_sycl_round(ctx, dst);
break;
case GGML_UNARY_OP_TRUNC:
ggml_sycl_trunc(ctx, dst);
break;
default:
return false;
}
@@ -3728,6 +3744,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_CONCAT:
ggml_sycl_op_concat(ctx, dst);
break;
case GGML_OP_PAD_REFLECT_1D:
ggml_sycl_op_pad_reflect_1d(ctx,dst);
break;
case GGML_OP_UPSCALE:
ggml_sycl_upscale(ctx, dst);
break;
@@ -4258,6 +4277,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
#if defined (GGML_SYCL_F16)
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
#else
@@ -4331,6 +4354,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
}
case GGML_OP_SET:
return (op->type == GGML_TYPE_F32) &&
(op->src[0] && op->src[1]) &&
(op->src[0]->type == GGML_TYPE_F32) &&
(op->src[1]->type == GGML_TYPE_F32);
case GGML_OP_SET_ROWS:
{
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
@@ -4429,6 +4458,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_DIV:
case GGML_OP_REPEAT:
return true;
case GGML_OP_PAD_REFLECT_1D:
return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
+72
View File
@@ -0,0 +1,72 @@
#include "pad_reflect_1d.hpp"
void pad_reflect_1d_f32(const float* src,float* dst,
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const sycl::nd_item<3> &item_ct1){
const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
const int i1 = item_ct1.get_group(1);
const int g2 = item_ct1.get_group(2);
const int i2 = g2 % ne02;
const int i3 = g2 / ne02;
if (i0 >= p0 + ne0 + p1) return;
int t = i0 - p0;
int period = 2 * ne0 -2;
int m = t % period;
m += (m < 0) * period;
int center = ne0 -1;
int srci0 = center - abs(center - m);
int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
dst[offest_dst] = src[offest_src];
}
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
const ggml_tensor * src0 = dst->src[0];
queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *) dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];
const int64_t ne0 = src0->ne[0];
const int64_t ne00 = dst->ne[0];
const int64_t ne01 = dst->ne[1];
const int64_t ne02 = dst->ne[2];
const int64_t ne03 = dst->ne[3];
const int64_t nb00 = dst->nb[0];
const int64_t nb01 = dst->nb[1];
const int64_t nb02 = dst->nb[2];
const int64_t nb03 = dst->nb[3];
const int64_t nb0 = src0->nb[0];
const int64_t nb1 = src0->nb[1];
const int64_t nb2 = src0->nb[2];
const int64_t nb3 = src0->nb[3];
int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
stream->parallel_for(
sycl::nd_range<3>(global,
local),
[=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
(const float *) src0->data, (float *) dst->data,
ne0, ne02, p0, p1,
nb0, nb1, nb2, nb3,
nb00, nb01, nb02, nb03
, item_ct1);
});
}
+8
View File
@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_PAD_REFLECT_1D_HPP
#define GGML_SYCL_PAD_REFLECT_1D_HPP
#include "common.hpp"
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
#endif // GGML_SYCL_PAD_REFLECT_1D_HPP
+1
View File
@@ -31,6 +31,7 @@
#define SYCL_SQRT_BLOCK_SIZE 256
#define SYCL_SIN_BLOCK_SIZE 256
#define SYCL_SQR_BLOCK_SIZE 256
#define SYCL_SET_BLOCK_SIZE 256
#define SYCL_CPY_BLOCK_SIZE 32
#define SYCL_SCALE_BLOCK_SIZE 256
#define SYCL_CLAMP_BLOCK_SIZE 256
+73
View File
@@ -0,0 +1,73 @@
#include "presets.hpp"
#include "common.hpp"
#include "ggml.h"
#include "set.hpp"
#include <cstdint>
#include <sycl/sycl.hpp>
using namespace sycl;
// Internal function: perform element-wise set operation for each thread
inline void set_f32(const float* src, float* dst,
const int64_t ne0, const int64_t ne1,
const int64_t ne2, const int64_t ne3,
const int64_t nb[3], const int64_t src_nb[3],
const int64_t offset_elem,
const nd_item<1>& item)
{
const size_t idx = item.get_global_id(0);
const size_t total = ne0 * ne1 * ne2 * ne3;
if (idx >= total) return;
// Convert linear index to 4D indices
const size_t i3 = idx / (ne2 * ne1 * ne0);
const size_t rem = idx % (ne2 * ne1 * ne0);
const size_t i2 = rem / (ne1 * ne0);
const size_t rem2 = rem % (ne1 * ne0);
const size_t i1 = rem2 / ne0;
const size_t i0 = rem2 % ne0;
// Compute source and destination indices and copy
dst[i0 + i1*nb[0] + i2*nb[1] + i3*nb[2] + offset_elem] =
src[i0 + i1*src_nb[0] + i2*src_nb[1] + i3*src_nb[2]];
}
// Main function: prepare GPU queue and launch parallel_for
void ggml_sycl_op_set(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor* src0 = dst->src[0];
const ggml_tensor* src1 = dst->src[1];
// Ensure shapes and types are compatible
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
GGML_ASSERT(dst->type == src0->type && src0->type == src1->type && dst->type == GGML_TYPE_F32);
const int32_t* opts = (const int32_t*) dst->op_params;
const int64_t nb[3] = {opts[0]/sizeof(float), opts[1]/sizeof(float), opts[2]/sizeof(float)};
const int64_t offset_elem = opts[3] / sizeof(float);
const bool inplace = opts[4];
float* dst_ptr = (float*) dst->data;
const float* src0_ptr = (const float*) src0->data;
const float* src1_ptr = (const float*) src1->data;
queue_ptr stream = ctx.stream();
// Copy src0 to dst if not inplace
if (!inplace)
stream->memcpy(dst_ptr, src0_ptr, ggml_nbytes(dst));
const int64_t ne[4] = {src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]};
const int64_t src_nb[3] = {src1->nb[1]/sizeof(float), src1->nb[2]/sizeof(float), src1->nb[3]/sizeof(float)};
const size_t total_threads = ne[0]*ne[1]*ne[2]*ne[3];
const size_t grid_size = ((total_threads + SYCL_SET_BLOCK_SIZE - 1) / SYCL_SET_BLOCK_SIZE) * SYCL_SET_BLOCK_SIZE;
// Copy src0 to dst if not inplace
stream->parallel_for(
nd_range<1>(range<1>(grid_size), range<1>(SYCL_SET_BLOCK_SIZE)),
[=](nd_item<1> item) {
set_f32(src1_ptr, dst_ptr,
ne[0], ne[1], ne[2], ne[3],
nb, src_nb, offset_elem, item); }
);
}
+5
View File
@@ -0,0 +1,5 @@
#pragma once
#include "backend.hpp"
#include "ggml.h"
void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+481 -12
View File
@@ -385,6 +385,14 @@ enum shader_reduction_mode {
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -582,6 +590,9 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@@ -595,6 +606,9 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
// [2] is {!norm, norm}
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
std::vector<vk_pipeline_ref> all_pipelines;
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
@@ -938,6 +952,11 @@ struct vk_op_multi_add_push_constants {
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
};
struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
@@ -1087,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
uint32_t nb42, nb43, nb52, nb53;
uint32_t s_off;
uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
uint32_t nb01, nb02;
uint32_t nb11;
uint32_t dst_nb0, dst_nb1, dst_nb2;
uint32_t nc, ncs, nr, n_t, n_s;
};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
@@ -3591,6 +3623,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -3701,6 +3738,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
}
for (auto &c : compiles) {
c.wait();
}
@@ -7983,6 +8025,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
return ctx->device->pipeline_topk_moe[idx][with_norm];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
}
@@ -8098,6 +8147,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
if (d_state == 128) {
return ctx->device->pipeline_ssm_scan_f32_d128;
} else if (d_state == 256) {
return ctx->device->pipeline_ssm_scan_f32_d256;
}
}
return nullptr;
case GGML_OP_SSM_CONV:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ssm_conv_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
@@ -8592,6 +8656,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
break;
case GGML_OP_SSM_CONV:
{
const uint32_t nr = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
elements = { nr, n_t, n_s };
}
break;
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
@@ -9038,6 +9110,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = dst->src[3];
const ggml_tensor * src4 = dst->src[4];
const ggml_tensor * src5 = dst->src[5];
GGML_ASSERT(dst->buffer != nullptr);
const uint32_t head_dim = src0->ne[1];
const uint32_t n_head = src1->ne[1];
const uint32_t n_group = src4->ne[1];
const uint32_t n_tok = src1->ne[2];
const uint32_t n_seq = src1->ne[3];
bool is_mamba2 = (src3->nb[1] == sizeof(float));
GGML_ASSERT(is_mamba2);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
const vk_op_ssm_scan_push_constants pc = {
(uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
(uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
(uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
(uint32_t)src3->nb[1],
(uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
(uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
(uint32_t)s_off,
n_head, head_dim, n_group, n_tok
};
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
if (ctx->device->uma) {
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
}
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}
size_t dst_size = ggml_nbytes(dst);
size_t src_sizes[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
}
std::array<uint32_t, 3> elements;
const int splitH = 16;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
}
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
(uint32_t)src1->ne[0],
(uint32_t)src0->ne[0],
(uint32_t)src0->ne[1],
(uint32_t)dst->ne[1],
(uint32_t)dst->ne[2],
}, dryrun);
}
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * x = dst->src[0];
const ggml_tensor * g = dst->src[1];
@@ -9434,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];
const int n_expert_used = weights->ne[1];
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
vk_buffer d_logits = nullptr;
size_t logits_buf_offset = 0;
vk_buffer d_weights = nullptr;
size_t weights_buf_offset = 0;
vk_buffer d_ids = nullptr;
size_t ids_buf_offset = 0;
bool logits_uma = false;
bool weights_uma = false;
bool ids_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
logits_uma = d_logits != nullptr;
weights_uma = d_weights != nullptr;
ids_uma = d_ids != nullptr;
}
if (!logits_uma) {
d_logits = logits_buf_ctx->dev_buffer;
logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
GGML_ASSERT(d_logits != nullptr);
}
if (!weights_uma) {
d_weights = weights_buf_ctx->dev_buffer;
weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
GGML_ASSERT(d_weights != nullptr);
}
if (!ids_uma) {
d_ids = ids_buf_ctx->dev_buffer;
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
GGML_ASSERT(d_ids != nullptr);
}
vk_op_topk_moe_push_constants pc;
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
GGML_ASSERT(n_expert_used <= n_experts);
const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
}, pc, elements);
}
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@@ -10870,6 +11134,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
@@ -11017,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
}
// Add the last fused node and all fused source nodes to the unsynchronized list.
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
ctx->unsynced_nodes_written.push_back(last_node);
// Add all fused nodes to the unsynchronized lists.
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
// Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
ctx->unsynced_nodes_written.push_back(cur_node);
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
@@ -11188,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SOFT_MAX:
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
if (ctx->num_additional_fused_ops) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
} else {
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
}
break;
case GGML_OP_SOFT_MAX_BACK:
@@ -11287,6 +11557,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_SSM_CONV:
ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
@@ -11398,6 +11678,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
@@ -11972,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
return true;
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
int node_idx, bool with_norm) {
if (with_norm) {
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
return false;
}
}
} else {
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
return false;
}
}
}
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
float max_bias = op_params[1];
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
return false;
}
// Check that the nodes don't have any unexpected uses
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
// softmax is used by reshape and argsort
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
reshape1->src[0] != softmax ||
argsort->src[0] != softmax) {
return false;
}
// reshape is used by get_rows
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
get_rows->src[0] != reshape1) {
return false;
}
// argsort is used by view
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
view->src[0] != argsort) {
return false;
}
// view is written (via argsort), we can skip checking it
if (with_norm) {
// get_rows is used by reshape
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
reshape5->src[0] != get_rows) {
return false;
}
// reshape is used by sum_rows and div
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
sum_rows->src[0] != reshape5 ||
div->src[0] != reshape5) {
return false;
}
// sum_rows is used by div
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
div->src[1] != sum_rows) {
return false;
}
// div/reshape are written
if (reshape8->src[0] != div) {
return false;
}
}
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
ctx->device->disable_fusion) {
return false;
}
return true;
}
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
@@ -12047,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12144,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
}
}
@@ -12151,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i + ctx->num_additional_fused_ops == last_node) ||
(i + ctx->num_additional_fused_ops >= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
if (vk_perf_logger_enabled) {
if (ctx->compute_ctx.expired()) {
@@ -12275,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
// Avoid reordering topk_moe_norm
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
bool is_topk_moe_norm = true;
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
is_topk_moe_norm = false;
}
}
if (is_topk_moe_norm) {
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
continue;
}
}
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12879,6 +13302,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
return false;
}
}
if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
return false;
}
const uint32_t d_state = op->src[0]->ne[0];
const uint32_t head_dim = op->src[0]->ne[1];
bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
if (!is_mamba2) {
return false;
}
if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
const uint32_t SPLIT_H = 16;
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
return true;
}
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
@@ -13223,14 +13687,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
struct ggml_context * ggml_ctx = ggml_init(iparams);
std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, GGML_MAX_SRC> src_size = {};
std::array<void *, GGML_MAX_SRC> src_buffer = {};
const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
struct ggml_tensor * tensor_clone = nullptr;
for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
@@ -13537,6 +14001,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
} else if (tensor->op == GGML_OP_SSM_SCAN) {
tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_SSM_CONV) {
tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -13558,7 +14027,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (src_buffer[i] != nullptr) {
free(src_buffer[i]);
}
@@ -345,7 +345,7 @@ void main() {
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
@@ -380,7 +380,7 @@ void main() {
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
@@ -121,7 +121,11 @@ void main() {
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
@@ -294,7 +298,7 @@ void main() {
[[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
}
O = Ldiag*O;
@@ -91,7 +91,7 @@ void main() {
L = L*ms + vs;
}
L = 1.0 / L;
L = (L == 0.0) ? 0.0 : 1.0 / L;
// D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
@@ -0,0 +1,44 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };
layout(push_constant) uniform PushConstants {
uint nb01; uint nb02;
uint nb11;
uint dst_nb0; uint dst_nb1; uint dst_nb2;
uint nc; uint ncs; uint nr; uint n_t; uint n_s;
};
void main() {
const uint global_thread_id = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i3 = gl_WorkGroupID.z;
if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
return;
}
const uint i1 = global_thread_id;
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
const uint src1_base = i1 * (nb11 / 4);
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
float sum = 0.0;
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
const uint src0_idx = src0_base + i0;
const uint src1_idx = src1_base + i0;
sum += src0[src0_idx] * src1[src1_idx];
}
dst[dst_idx] = sum;
}
@@ -0,0 +1,125 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float s0[]; };
layout(binding = 1) readonly buffer Src1 { float x[]; };
layout(binding = 2) readonly buffer Src2 { float dt[]; };
layout(binding = 3) readonly buffer Src3 { float A[]; };
layout(binding = 4) readonly buffer Src4 { float B[]; };
layout(binding = 5) readonly buffer Src5 { float C[]; };
layout(binding = 6) readonly buffer Src6 { int ids[]; };
layout(binding = 7) buffer Dst { float d[]; };
layout(push_constant) uniform PushConstants {
uint nb02; uint nb03; uint nb12; uint nb13;
uint nb21; uint nb22; uint nb31;
uint nb42; uint nb43; uint nb52; uint nb53;
uint s_off;
uint n_head;
uint d_head;
uint n_group;
uint n_tok;
};
float softplus(float x) {
if (x <= 20.0) {
return log(1.0 + exp(x));
} else {
return x;
}
}
shared float stateC[SPLIT_H * D_STATE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
const uint seq_idx = gl_WorkGroupID.y;
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
const uint A_base_idx = (head_idx * nb31) / 4;
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint stride_x = nb12 / 4;
const uint stride_dt = nb21 / 4;
const uint stride_B = nb42 / 4;
const uint stride_C = nb52 / 4;
const uint stride_y = n_head * d_head;
float state[SPLIT_H];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
state[j] = s0[s0_base_idx + j * D_STATE + tid];
}
for (uint i = 0; i < n_tok; i++) {
const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
const float dA = exp(dt_soft_plus * A[A_base_idx]);
const float B_val = B[B_base_idx + i * stride_B + tid];
const float C_val = C[C_base_idx + i * stride_C + tid];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
state[j] = (state[j] * dA) + (B_val * x_dt);
stateC[j * D_STATE + tid] = state[j] * C_val;
}
barrier();
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % (w >> 1)) +
(D_STATE * (tid / (w >> 1))) +
j * D_STATE * (D_STATE / (w >> 1));
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + (w >> 1)];
}
}
barrier();
}
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
uint lane = tid % SUBGROUP_SIZE;
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = stateC[idx];
}
}
barrier();
}
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
d[s_base_idx + j * D_STATE + tid] = state[j];
}
}
@@ -0,0 +1,139 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
return;
}
const uint logits_offset = n_experts * row;
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
float logits_r[experts_per_thread];
const float INFINITY = 1.0 / 0.0;
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
}
float max_val = logits_r[0];
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
}
max_val = subgroupMax(max_val);
float wt[experts_per_thread];
float tmp = 0.f;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = exp(val - max_val);
tmp += wt[i];
}
tmp = subgroupAdd(tmp);
const float inv_sum = 1.0f / tmp;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
}
// at this point, each thread holds a portion of softmax,
// we do the argmax reduce over n_expert_used, each time marking
// the expert weight as -inf to exclude from the next iteration
float wt_sum = 0.f;
float output_weights[experts_per_thread];
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
[[unroll]]
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = subgroupShuffleXor(max_val, mask);
const uint expert = subgroupShuffleXor(max_expert, mask);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
output_weights[k / WARP_SIZE] = max_val;
}
if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[ids_offset + k] = max_expert;
if (with_norm) {
wt_sum += max_val;
}
}
}
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
output_weights[i] *= inv_sum;
}
}
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
if (idx < n_expert_used) {
weights[weights_offset + idx] = output_weights[i];
}
}
}
@@ -916,6 +916,12 @@ void process_shaders() {
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
string_to_spv("topk_moe_f32", "topk_moe.comp", {});
for (auto &c : compiles) {
c.wait();
}
@@ -959,7 +965,7 @@ void write_output_files() {
}
std::string suffixes[2] = {"_f32", "_f16"};
for (auto op : {"add", "sub", "mul", "div", "add_rms"}) {
for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
+72
View File
@@ -6964,6 +6964,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_LOG_INFO("========================================\n");
}
static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
const int * idxs,
int count,
const struct ggml_tensor * tensor) {
GGML_ASSERT(cgraph && idxs);
for (int i = 0; i < count; ++i) {
const int node_idx = idxs[i];
if (node_idx >= cgraph->n_nodes) {
return -1;
}
if (cgraph->nodes[node_idx] == tensor) {
return i;
}
}
return -1;
}
bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
const int * node_idxs,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs) {
GGML_ASSERT(outputs && num_outputs > 0);
for (int i = 0; i < count; ++i) {
if (node_idxs[i] >= cgraph->n_nodes) {
return false;
}
const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
if (node->op != ops[i]) {
return false;
}
if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
continue;
}
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
return false;
}
int subgraph_uses = 0;
for (int j = i + 1; j < count; ++j) {
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
if (other_node->src[src_idx] == node) {
subgraph_uses++;
}
}
}
if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
return false;
}
// if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
struct ggml_tensor * view_src = node->view_src;
while (view_src) {
if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
return false;
}
view_src = view_src->view_src;
}
}
return true;
}
// check if node is part of the graph
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
if (cgraph == NULL) {
+33
View File
@@ -102,6 +102,8 @@ class Keys:
EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count"
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
@@ -400,6 +402,7 @@ class MODEL_ARCH(IntEnum):
WAVTOKENIZER_DEC = auto()
PLM = auto()
BAILINGMOE = auto()
BAILINGMOE2 = auto()
DOTS1 = auto()
ARCEE = auto()
ERNIE4_5 = auto()
@@ -744,6 +747,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
@@ -2533,6 +2537,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.BAILINGMOE2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.DOTS1: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
+6
View File
@@ -755,6 +755,12 @@ class GGUFWriter:
def add_expert_shared_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
def add_expert_group_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)
def add_expert_group_used_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)
def add_expert_weights_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
+6
View File
@@ -174,6 +174,7 @@ class TensorNameMap:
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
"model.layers.{bid}.self_attn.query_key_value", # persimmon
"model.layers.{bid}.attention.query_key_value", # bailingmoe2
"h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
@@ -260,6 +261,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"model.layers.{bid}.attention.dense", # bailingmoe2
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
@@ -373,6 +375,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_EXP_PROBS_B: (
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
),
@@ -549,6 +552,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.attention.query_layernorm", # bailingmoe2
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"layers.{bid}.self_attn.q_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
@@ -563,6 +567,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.attention.key_layernorm", # bailingmoe2
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"layers.{bid}.self_attn.k_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
@@ -584,6 +589,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2", # jina-v2-code
"model.layers.{bid}.final_layernorm", # bailingmoe2
),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
+35
View File
@@ -85,6 +85,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
@@ -135,6 +136,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
{ LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" },
{ LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
@@ -1946,6 +1949,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_BAILINGMOE2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
},
},
{
LLM_ARCH_DOTS1,
{
+3
View File
@@ -89,6 +89,7 @@ enum llm_arch {
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_BAILINGMOE2,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
@@ -139,6 +140,8 @@ enum llm_kv {
LLM_KV_EXPERT_COUNT,
LLM_KV_EXPERT_USED_COUNT,
LLM_KV_EXPERT_SHARED_COUNT,
LLM_KV_EXPERT_GROUP_COUNT,
LLM_KV_EXPERT_GROUP_USED_COUNT,
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
+1 -1
View File
@@ -123,7 +123,7 @@ private:
uint32_t n_seq_max;
uint32_t n_outputs;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
+35 -2
View File
@@ -63,6 +63,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK },
{ "bailing2", LLM_CHAT_TEMPLATE_BAILING2 },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
@@ -191,6 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("\"HUMAN\"") && tmpl_contains("<think>")) {
return LLM_CHAT_TEMPLATE_BAILING_THINK;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("<role>HUMAN</role>") && tmpl_contains("<|role_end|>")) {
return LLM_CHAT_TEMPLATE_BAILING2;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
@@ -644,8 +650,8 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << " Ассистент:[SEP]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
// Bailing (Ling) template
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
// Bailing (Ling/Ring) template
for (auto message : chat) {
std::string role(message->role);
@@ -658,6 +664,33 @@ int32_t llm_chat_apply_template(
ss << "<role>" << role << "</role>" << message->content;
}
if (add_ass) {
ss << "<role>ASSISTANT</role>";
if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
ss << "<think>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) {
// Bailing2 (Ling 2.0) template
bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
if (!has_system) {
ss << "<role>SYSTEM</role>detailed thinking off<|role_end|>";
}
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}
ss << "<role>" << role << "</role>" << message->content << "<|role_end|>";
}
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
+2
View File
@@ -42,6 +42,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_BAILING_THINK,
LLM_CHAT_TEMPLATE_BAILING2,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
+2 -1
View File
@@ -2346,7 +2346,8 @@ llama_context * llama_init_from_model(
return nullptr;
}
if (params.pooling_type != model->hparams.pooling_type) {
if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
params.pooling_type != model->hparams.pooling_type) {
//user-specified pooling-type is different from the model default
LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
model->hparams.pooling_type, params.pooling_type);
+30
View File
@@ -950,6 +950,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(selection_probs, "ffn_moe_probs_biased", il);
}
// select top n_group_used expert groups
// https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
if (hparams.n_expert_groups > 1 && n_tokens > 0) {
const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
// organize experts into n_expert_groups
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
// get top n_group_used expert groups
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
cb(expert_groups, "ffn_moe_group_topk", il);
// mask out the other groups
selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
cb(selection_probs, "ffn_moe_probs_masked", il);
}
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
@@ -981,6 +1006,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);
if (arch == LLM_ARCH_BAILINGMOE2) {
weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
cb(weights_sum, "ffn_moe_weights_sum_biased", il);
}
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il);
+2
View File
@@ -72,6 +72,8 @@ struct llama_hparams {
uint32_t n_ff_chexp = 0;
uint32_t n_expert_shared = 0;
uint32_t n_norm_groups = 0;
uint32_t n_expert_groups = 0;
uint32_t n_group_used = 0;
uint32_t n_group_experts = 0;
float expert_group_scale = 0.05f;
+298 -39
View File
@@ -114,9 +114,12 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_A13B: return "A13B";
case LLM_TYPE_7B_A1B: return "7B.A1B";
case LLM_TYPE_8B_A1B: return "8B.A1B";
case LLM_TYPE_16B_A1B: return "16B.A1B";
case LLM_TYPE_21B_A3B: return "21B.A3B";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_100B_A6B: return "100B.A6B";
case LLM_TYPE_106B_A12B: return "106B.A12B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_300B_A47B: return "300B.A47B";
@@ -421,11 +424,8 @@ struct llama_model::impl {
llama_mlocks mlock_bufs;
llama_mlocks mlock_mmaps;
// contexts where the model tensors metadata is stored
std::vector<ggml_context_ptr> ctxs;
// the model memory buffers for the tensor data
std::vector<ggml_backend_buffer_ptr> bufs;
// contexts where the model tensors metadata is stored as well ass the corresponding buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
buft_list_t cpu_buft_list;
std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
@@ -483,11 +483,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
return;
}
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@@ -503,8 +505,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
if (hparams.n_expert > 0) {
GGML_ASSERT(hparams.n_expert_used > 0);
GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert);
if (hparams.n_expert_groups > 1) {
GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0);
GGML_ASSERT(hparams.n_group_used > 0);
GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups);
}
} else {
GGML_ASSERT(hparams.n_expert_used == 0);
GGML_ASSERT(hparams.n_expert_groups == 0);
}
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
@@ -1846,8 +1855,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
// TODO: Add llm type label (not sure this is useful)
switch (hparams.n_embd) {
case 1536: type = LLM_TYPE_7B_A1B; break;
case 2048: case 2560: type = LLM_TYPE_3B; break;
case 4096: type = LLM_TYPE_32B; break;
default: type = LLM_TYPE_UNKNOWN;
}
@@ -1888,6 +1899,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_BAILINGMOE2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 20: type = LLM_TYPE_16B_A1B; break;
case 21: type = LLM_TYPE_16B_A1B; break;
case 32: type = LLM_TYPE_100B_A6B; break;
case 33: type = LLM_TYPE_100B_A6B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_DOTS1:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -2182,7 +2216,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
max_n_tensors += n_layer*2; // duplicated rope freq tensors
const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
@@ -2197,12 +2238,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
throw std::runtime_error(format("failed to create ggml context"));
}
ctx_map[buft] = ctx;
pimpl->ctxs.emplace_back(ctx);
ctx_map.emplace(buft, ctx);
return ctx;
}
return it->second;
return it->second.get();
};
const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED;
@@ -5492,6 +5532,70 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
}
} break;
case LLM_ARCH_BAILINGMOE2:
{
const int64_t n_ff_exp = hparams.n_ff_exp;
const int64_t n_expert_shared = hparams.n_expert_shared;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2");
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2");
for (int i = 0; i < n_layer; ++i) {
int flags = 0;
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
// skip all tensors in the NextN layers
flags |= TENSOR_SKIP;
}
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
} else { // Dense layers
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
}
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags);
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags);
}
}
} break;
case LLM_ARCH_DOTS1:
{
const int64_t n_ff_exp = hparams.n_ff_exp;
@@ -6037,16 +6141,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
pimpl->mappings.reserve(ml.mappings.size());
// create the backend buffers
std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
ctx_bufs.reserve(ctx_map.size());
std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_buf_maps;
ctx_buf_maps.reserve(ctx_map.size());
// Ensure we have enough capacity for the maximum backend buffer we will potentially create
const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
pimpl->bufs.reserve(n_max_backend_buffer);
pimpl->ctxs_bufs.reserve(n_max_backend_buffer);
for (auto & it : ctx_map) {
ggml_backend_buffer_type_t buft = it.first;
ggml_context * ctx = it.second;
for (auto & [buft, ctx_ptr] : ctx_map) {
ggml_context * ctx = ctx_ptr.get();
// skip contexts without tensors
if (ggml_get_first_tensor(ctx) == nullptr) {
@@ -6070,6 +6173,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
ggml_backend_buffer_t buf = nullptr;
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
// only the mmap region containing the tensors in the model is mapped to the backend buffer
@@ -6082,20 +6186,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
continue;
}
const size_t max_size = ggml_get_max_tensor_size(ctx);
ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
pimpl->bufs.emplace_back(buf);
buf_map.emplace(idx, buf);
}
}
else {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (buf == nullptr) {
throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
}
pimpl->bufs.emplace_back(buf);
if (use_mlock && ggml_backend_buffer_is_host(buf)) {
pimpl->mlock_bufs.emplace_back(new llama_mlock);
auto & mlock_buf = pimpl->mlock_bufs.back();
@@ -6106,10 +6208,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
buf_map.emplace(idx, buf);
}
}
if (pimpl->bufs.empty()) {
throw std::runtime_error("failed to allocate buffer");
}
pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
for (auto & buf : buf_map) {
// indicate that this buffer contains weights
@@ -6117,7 +6216,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
}
ctx_bufs.emplace_back(ctx, buf_map);
ctx_buf_maps.emplace_back(ctx, buf_map);
}
if (llama_supports_gpu_offload()) {
@@ -6135,22 +6234,20 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
// print memory requirements per buffer type
for (auto & buf : pimpl->bufs) {
for (auto & [_, buf] : pimpl->ctxs_bufs) {
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
}
// populate tensors_by_name
for (auto & ctx : pimpl->ctxs) {
for (auto & [ctx, _] : pimpl->ctxs_bufs) {
for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
tensors_by_name.emplace_back(ggml_get_name(cur), cur);
}
}
// load tensor data
for (auto & it : ctx_bufs) {
ggml_context * ctx = it.first;
auto & bufs = it.second;
if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
for (auto & [ctx, buf_map] : ctx_buf_maps) {
if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
return false;
}
}
@@ -6190,8 +6287,8 @@ size_t llama_model::n_devices() const {
std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : pimpl->ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
@@ -6354,6 +6451,19 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
}
if (arch == LLM_ARCH_BAILINGMOE2) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers);
}
if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
@@ -17043,6 +17153,150 @@ struct llm_build_bailingmoe : public llm_graph_context {
}
};
struct llm_build_bailingmoe2 : public llm_graph_context {
llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self_attention
{
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
cb(sa_out, "sa_out", il);
// MoE branch
cur = build_norm(sa_out,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
ggml_tensor * moe_out =
build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);
{
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
}
cur = ggml_add(ctx0, cur, sa_out);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_dots1 : public llm_graph_context {
llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -19839,6 +20093,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_bailingmoe>(*this, params);
} break;
case LLM_ARCH_BAILINGMOE2:
{
llm = std::make_unique<llm_build_bailingmoe2>(*this, params);
} break;
case LLM_ARCH_SEED_OSS:
{
llm = std::make_unique<llm_build_seed_oss>(*this, params);
@@ -20105,6 +20363,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_EXAONE:
case LLM_ARCH_EXAONE4:
case LLM_ARCH_MINICPM3:
case LLM_ARCH_BAILINGMOE2:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_OPENAI_MOE:
+3
View File
@@ -107,9 +107,12 @@ enum llm_type {
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_A13B,
LLM_TYPE_7B_A1B,
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_16B_A1B,
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
LLM_TYPE_100B_A6B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_235B_A22B,
LLM_TYPE_300B_A47B, // Ernie MoE big
+1
View File
@@ -1968,6 +1968,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
clean_spaces = false;
} else if (
tokenizer_pre == "bailingmoe" ||
tokenizer_pre == "bailingmoe2" ||
tokenizer_pre == "llada-moe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
clean_spaces = false;
+156 -6
View File
@@ -3759,6 +3759,130 @@ struct test_clamp : public test_case {
}
};
// GGML_OP_FLOOR
struct test_floor : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_floor(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_floor(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_CEIL
struct test_ceil : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_ceil(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_ceil(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_ROUND
struct test_round : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_round(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_round(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_TRUNC
struct test_trunc : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_trunc(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 2, 2, 2})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_trunc(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}
};
// GGML_OP_DIAG_MASK_INF
struct test_diag_mask_inf : public test_case {
const ggml_type type;
@@ -4545,14 +4669,21 @@ struct test_topk_moe: public test_case {
const std::array<int64_t, 4> ne;
const int n_expert_used;
const bool with_norm;
test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
: ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
const bool delayed_softmax;
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
int n_expert_used = 1,
bool with_norm = false,
bool delayed_softmax = false) :
ne(ne),
n_expert_used(n_expert_used),
with_norm(with_norm),
delayed_softmax(delayed_softmax) {
GGML_ASSERT(n_expert_used <= ne[0]);
GGML_ASSERT(!(with_norm && delayed_softmax));
}
std::string vars() override {
return VARS_TO_STR3(ne, n_expert_used, with_norm);
}
std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@@ -4566,11 +4697,17 @@ struct test_topk_moe: public test_case {
const int n_tokens = ne[1];
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_tensor * probs = ggml_soft_max(ctx, logits);
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
if (delayed_softmax) {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
}
if (with_norm) {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
@@ -6585,6 +6722,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cos (type));
test_cases.emplace_back(new test_clamp (type));
test_cases.emplace_back(new test_leaky_relu(type));
test_cases.emplace_back(new test_floor (type));
test_cases.emplace_back(new test_ceil (type));
test_cases.emplace_back(new test_round (type));
test_cases.emplace_back(new test_trunc (type));
test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));
@@ -6592,6 +6733,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
}
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
@@ -6843,6 +6988,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
}
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging
test_cases.emplace_back(new test_llama(2, true));
@@ -6989,6 +7137,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+24
View File
@@ -301,6 +301,30 @@ static void test_simple_grammar() {
"0123",
}
);
test_schema(
"min 1 max 900719925474091",
// Schema
R"""({
"type": "integer",
"exclusiveMinimum": 0,
"maximum": 900719925474091
})""",
// Passing strings
{
"1",
"2",
"10",
"900719925474090",
"900719925474091",
},
// Failing strings
{
"0",
"01",
"900719925474092",
"9007199254740910",
}
);
test_schema(
"min -1 max 1",
R"""({
+6 -4
View File
@@ -3,6 +3,7 @@
// - Creates n_parallel (--parallel) contexts per model
// - Runs inference in parallel on each context
#include <array>
#include <thread>
#include <vector>
#include <atomic>
@@ -38,13 +39,14 @@ int main(int argc, char ** argv) {
cparams.n_seq_max = 1;
int dev_count = ggml_backend_dev_count();
int gpu_dev_count = 0;
std::vector<std::array<ggml_backend_dev_t, 2>> gpus;
for (int i = 0; i < dev_count; ++i) {
auto * dev = ggml_backend_dev_get(i);
if (dev && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
gpu_dev_count++;
gpus.push_back({dev, nullptr});
}
}
const int gpu_dev_count = (int)gpus.size();
const int num_models = gpu_dev_count + 1 + 1; // GPUs + 1 CPU model + 1 layer split
//const int num_models = std::max(1, gpu_dev_count);
const int num_contexts = std::max(1, params.n_parallel);
@@ -58,12 +60,12 @@ int main(int argc, char ** argv) {
if (m < gpu_dev_count) {
mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
mparams.main_gpu = m;
mparams.devices = gpus[m].data();
} else if (m == gpu_dev_count) {
mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
mparams.main_gpu = -1; // CPU model
} else {
mparams.split_mode = LLAMA_SPLIT_MODE_LAYER;;
mparams.split_mode = LLAMA_SPLIT_MODE_LAYER;
}
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
+2
View File
@@ -30,6 +30,7 @@
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
@@ -48,6 +49,7 @@
#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"
// audio-specific
#define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
+15 -3
View File
@@ -2221,15 +2221,27 @@ struct clip_model_loader {
// projector type
std::string proj_type;
{
// default key
get_string(KEY_PROJ_TYPE, proj_type, false);
if (!proj_type.empty()) {
model.proj_type = clip_projector_type_from_string(proj_type);
// for models with mixed modalities
if (proj_type.empty()) {
if (modality == CLIP_MODALITY_VISION) {
get_string(KEY_VISION_PROJ_TYPE, proj_type, false);
} else if (modality == CLIP_MODALITY_AUDIO) {
get_string(KEY_AUDIO_PROJ_TYPE, proj_type, false);
} else {
GGML_ABORT("unknown modality");
}
}
model.proj_type = clip_projector_type_from_string(proj_type);
if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
}
// correct arch for multimodal models
// correct arch for multimodal models (legacy method)
if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
model.proj_type = modality == CLIP_MODALITY_VISION
? PROJECTOR_TYPE_QWEN25VL
+1 -33
View File
@@ -137,7 +137,6 @@ struct rpc_server_params {
bool use_cache = false;
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
std::vector<std::string> devices;
std::vector<size_t> dev_mem;
};
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
@@ -148,7 +147,6 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
}
@@ -197,23 +195,6 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
}
} else if (arg == "-c" || arg == "--cache") {
params.use_cache = true;
} else if (arg == "-m" || arg == "--mem") {
if (++i >= argc) {
return false;
}
const std::regex regex{ R"([,/]+)" };
std::string mem_str = argv[i];
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
std::sregex_token_iterator end;
for ( ; iter != end; ++iter) {
try {
size_t mem = std::stoul(*iter) * 1024 * 1024;
params.dev_mem.push_back(mem);
} catch (const std::exception & ) {
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
return false;
}
}
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params);
exit(0);
@@ -293,18 +274,6 @@ int main(int argc, char * argv[]) {
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
std::vector<size_t> free_mem, total_mem;
for (size_t i = 0; i < devices.size(); i++) {
if (i < params.dev_mem.size()) {
free_mem.push_back(params.dev_mem[i]);
total_mem.push_back(params.dev_mem[i]);
} else {
size_t free, total;
ggml_backend_dev_memory(devices[i], &free, &total);
free_mem.push_back(free);
total_mem.push_back(total);
}
}
const char * cache_dir = nullptr;
std::string cache_dir_str;
if (params.use_cache) {
@@ -328,7 +297,6 @@ int main(int argc, char * argv[]) {
return 1;
}
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
devices.data(), free_mem.data(), total_mem.data());
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
return 0;
}
Binary file not shown.
+527
View File
@@ -50,6 +50,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
@@ -2979,6 +2980,13 @@
"node": ">=4"
}
},
"node_modules/async": {
"version": "3.2.6",
"resolved": "https://registry.npmjs.org/async/-/async-3.2.6.tgz",
"integrity": "sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==",
"dev": true,
"license": "MIT"
},
"node_modules/axe-core": {
"version": "4.10.3",
"resolved": "https://registry.npmjs.org/axe-core/-/axe-core-4.10.3.tgz",
@@ -3015,6 +3023,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/basic-auth": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/basic-auth/-/basic-auth-2.0.1.tgz",
"integrity": "sha512-NF+epuEdnUYVlGuhaxbbq+dvJttwLnGY+YixlXlME5KpQ5W3CnXA5cVTneY3SPbPDRkcjMbifrwmFYcClgOZeg==",
"dev": true,
"license": "MIT",
"dependencies": {
"safe-buffer": "5.1.2"
},
"engines": {
"node": ">= 0.8"
}
},
"node_modules/better-opn": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/better-opn/-/better-opn-3.0.2.tgz",
@@ -3125,6 +3146,37 @@
"node": ">=8"
}
},
"node_modules/call-bind-apply-helpers": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz",
"integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"function-bind": "^1.1.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/call-bound": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz",
"integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.2",
"get-intrinsic": "^1.3.0"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/callsites": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
@@ -3335,6 +3387,16 @@
"node": ">= 0.6"
}
},
"node_modules/corser": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/corser/-/corser-2.0.1.tgz",
"integrity": "sha512-utCYNzRSQIZNPIcGZdQc92UVJYAhtGAteCFg0yRaFm8f0P+CPtyGyHXJcGXnffjCybUCEx3FQ2G7U3/o9eIkVQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4.0"
}
},
"node_modules/cross-spawn": {
"version": "7.0.6",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",
@@ -3520,6 +3582,21 @@
"dev": true,
"license": "MIT"
},
"node_modules/dunder-proto": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz",
"integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.1",
"es-errors": "^1.3.0",
"gopd": "^1.2.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/enhanced-resolve": {
"version": "5.18.2",
"resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.2.tgz",
@@ -3547,6 +3624,26 @@
"url": "https://github.com/fb55/entities?sponsor=1"
}
},
"node_modules/es-define-property": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz",
"integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-errors": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
"integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-module-lexer": {
"version": "1.7.0",
"resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz",
@@ -3554,6 +3651,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/es-object-atoms": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz",
"integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-toolkit": {
"version": "1.39.7",
"resolved": "https://registry.npmjs.org/es-toolkit/-/es-toolkit-1.39.7.tgz",
@@ -3885,6 +3995,13 @@
"node": ">=0.10.0"
}
},
"node_modules/eventemitter3": {
"version": "4.0.7",
"resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz",
"integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==",
"dev": true,
"license": "MIT"
},
"node_modules/expect-type": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.2.2.tgz",
@@ -4058,6 +4175,27 @@
"dev": true,
"license": "ISC"
},
"node_modules/follow-redirects": {
"version": "1.15.11",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz",
"integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==",
"dev": true,
"funding": [
{
"type": "individual",
"url": "https://github.com/sponsors/RubenVerborgh"
}
],
"license": "MIT",
"engines": {
"node": ">=4.0"
},
"peerDependenciesMeta": {
"debug": {
"optional": true
}
}
},
"node_modules/fsevents": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
@@ -4073,6 +4211,55 @@
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
}
},
"node_modules/function-bind": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
"integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-intrinsic": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
"integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.2",
"es-define-property": "^1.0.1",
"es-errors": "^1.3.0",
"es-object-atoms": "^1.1.1",
"function-bind": "^1.1.2",
"get-proto": "^1.0.1",
"gopd": "^1.2.0",
"has-symbols": "^1.1.0",
"hasown": "^2.0.2",
"math-intrinsics": "^1.1.0"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-proto": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
"integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==",
"dev": true,
"license": "MIT",
"dependencies": {
"dunder-proto": "^1.0.1",
"es-object-atoms": "^1.0.0"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/glob-parent": {
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",
@@ -4099,6 +4286,19 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/gopd": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz",
"integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/graceful-fs": {
"version": "4.2.11",
"resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz",
@@ -4123,6 +4323,32 @@
"node": ">=8"
}
},
"node_modules/has-symbols": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
"integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/hasown": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz",
"integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"function-bind": "^1.1.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/hast-util-from-dom": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/hast-util-from-dom/-/hast-util-from-dom-5.0.1.tgz",
@@ -4363,6 +4589,16 @@
"url": "https://opencollective.com/unified"
}
},
"node_modules/he": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/he/-/he-1.2.0.tgz",
"integrity": "sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw==",
"dev": true,
"license": "MIT",
"bin": {
"he": "bin/he"
}
},
"node_modules/highlight.js": {
"version": "11.11.1",
"resolved": "https://registry.npmjs.org/highlight.js/-/highlight.js-11.11.1.tgz",
@@ -4372,6 +4608,19 @@
"node": ">=12.0.0"
}
},
"node_modules/html-encoding-sniffer": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-3.0.0.tgz",
"integrity": "sha512-oWv4T4yJ52iKrufjnyZPkrN0CH3QnrUqdB6In1g5Fe1mia8GmF36gnfNySxoZtxD5+NmYw1EElVXiBk93UeskA==",
"dev": true,
"license": "MIT",
"dependencies": {
"whatwg-encoding": "^2.0.0"
},
"engines": {
"node": ">=12"
}
},
"node_modules/html-void-elements": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",
@@ -4382,6 +4631,62 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/http-proxy": {
"version": "1.18.1",
"resolved": "https://registry.npmjs.org/http-proxy/-/http-proxy-1.18.1.tgz",
"integrity": "sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"eventemitter3": "^4.0.0",
"follow-redirects": "^1.0.0",
"requires-port": "^1.0.0"
},
"engines": {
"node": ">=8.0.0"
}
},
"node_modules/http-server": {
"version": "14.1.1",
"resolved": "https://registry.npmjs.org/http-server/-/http-server-14.1.1.tgz",
"integrity": "sha512-+cbxadF40UXd9T01zUHgA+rlo2Bg1Srer4+B4NwIHdaGxAGGv59nYRnGGDJ9LBk7alpS0US+J+bLLdQOOkJq4A==",
"dev": true,
"license": "MIT",
"dependencies": {
"basic-auth": "^2.0.1",
"chalk": "^4.1.2",
"corser": "^2.0.1",
"he": "^1.2.0",
"html-encoding-sniffer": "^3.0.0",
"http-proxy": "^1.18.1",
"mime": "^1.6.0",
"minimist": "^1.2.6",
"opener": "^1.5.1",
"portfinder": "^1.0.28",
"secure-compare": "3.0.1",
"union": "~0.5.0",
"url-join": "^4.0.1"
},
"bin": {
"http-server": "bin/http-server"
},
"engines": {
"node": ">=12"
}
},
"node_modules/iconv-lite": {
"version": "0.6.3",
"resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz",
"integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==",
"dev": true,
"license": "MIT",
"dependencies": {
"safer-buffer": ">= 2.1.2 < 3.0.0"
},
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/ignore": {
"version": "5.3.2",
"resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz",
@@ -5008,6 +5313,16 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/math-intrinsics": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
"integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
}
},
"node_modules/mdast": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",
@@ -5976,6 +6291,19 @@
"url": "https://github.com/sponsors/jonschlinkert"
}
},
"node_modules/mime": {
"version": "1.6.0",
"resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz",
"integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==",
"dev": true,
"license": "MIT",
"bin": {
"mime": "cli.js"
},
"engines": {
"node": ">=4"
}
},
"node_modules/min-indent": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz",
@@ -6009,6 +6337,16 @@
"node": "*"
}
},
"node_modules/minimist": {
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz",
"integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/minipass": {
"version": "7.1.2",
"resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz",
@@ -6124,6 +6462,19 @@
"tslib": "^2.0.3"
}
},
"node_modules/object-inspect": {
"version": "1.13.4",
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
"integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/open": {
"version": "8.4.2",
"resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz",
@@ -6142,6 +6493,16 @@
"url": "https://github.com/sponsors/sindresorhus"
}
},
"node_modules/opener": {
"version": "1.5.2",
"resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz",
"integrity": "sha512-ur5UIdyw5Y7yEj9wLzhqXiy6GZ3Mwx0yGI+5sMn2r0N0v3cKJvUmFH5yPP+WXh9e0xfyzyJX95D8l088DNFj7A==",
"dev": true,
"license": "(WTFPL OR MIT)",
"bin": {
"opener": "bin/opener-bin.js"
}
},
"node_modules/optionator": {
"version": "0.9.4",
"resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz",
@@ -6330,6 +6691,20 @@
"node": ">=18"
}
},
"node_modules/portfinder": {
"version": "1.0.38",
"resolved": "https://registry.npmjs.org/portfinder/-/portfinder-1.0.38.tgz",
"integrity": "sha512-rEwq/ZHlJIKw++XtLAO8PPuOQA/zaPJOZJ37BVuN97nLpMJeuDVLVGRwbFoBgLudgdTMP2hdRJP++H+8QOA3vg==",
"dev": true,
"license": "MIT",
"dependencies": {
"async": "^3.2.6",
"debug": "^4.3.6"
},
"engines": {
"node": ">= 10.12"
}
},
"node_modules/postcss": {
"version": "8.5.6",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz",
@@ -6680,6 +7055,22 @@
"node": ">=6"
}
},
"node_modules/qs": {
"version": "6.14.0",
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz",
"integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==",
"dev": true,
"license": "BSD-3-Clause",
"dependencies": {
"side-channel": "^1.1.0"
},
"engines": {
"node": ">=0.6"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/queue-microtask": {
"version": "1.2.3",
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
@@ -6959,6 +7350,13 @@
"url": "https://opencollective.com/unified"
}
},
"node_modules/requires-port": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz",
"integrity": "sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==",
"dev": true,
"license": "MIT"
},
"node_modules/resolve-from": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz",
@@ -7072,6 +7470,20 @@
"node": ">=6"
}
},
"node_modules/safe-buffer": {
"version": "5.1.2",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz",
"integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==",
"dev": true,
"license": "MIT"
},
"node_modules/safer-buffer": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
"integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==",
"dev": true,
"license": "MIT"
},
"node_modules/scheduler": {
"version": "0.26.0",
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.26.0.tgz",
@@ -7079,6 +7491,13 @@
"dev": true,
"license": "MIT"
},
"node_modules/secure-compare": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/secure-compare/-/secure-compare-3.0.1.tgz",
"integrity": "sha512-AckIIV90rPDcBcglUwXPF3kg0P0qmPsPXAj6BBEENQE1p5yA1xfmDJzfi1Tappj37Pv2mVbKpL3Z1T+Nn7k1Qw==",
"dev": true,
"license": "MIT"
},
"node_modules/semver": {
"version": "7.7.2",
"resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz",
@@ -7122,6 +7541,82 @@
"node": ">=8"
}
},
"node_modules/side-channel": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz",
"integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"object-inspect": "^1.13.3",
"side-channel-list": "^1.0.0",
"side-channel-map": "^1.0.1",
"side-channel-weakmap": "^1.0.2"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-list": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz",
"integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"object-inspect": "^1.13.3"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-map": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz",
"integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bound": "^1.0.2",
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.5",
"object-inspect": "^1.13.3"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/side-channel-weakmap": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz",
"integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bound": "^1.0.2",
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.5",
"object-inspect": "^1.13.3",
"side-channel-map": "^1.0.1"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/siginfo": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz",
@@ -7904,6 +8399,18 @@
"integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==",
"license": "MIT"
},
"node_modules/union": {
"version": "0.5.0",
"resolved": "https://registry.npmjs.org/union/-/union-0.5.0.tgz",
"integrity": "sha512-N6uOhuW6zO95P3Mel2I2zMsbsanvvtgn6jVqJv4vbVcz/JN0OkL9suomjQGmWtxJQXOCqUJvquc1sMeNz/IwlA==",
"dev": true,
"dependencies": {
"qs": "^6.4.0"
},
"engines": {
"node": ">= 0.8.0"
}
},
"node_modules/unist-util-find-after": {
"version": "5.0.0",
"resolved": "https://registry.npmjs.org/unist-util-find-after/-/unist-util-find-after-5.0.0.tgz",
@@ -8073,6 +8580,13 @@
"punycode": "^2.1.0"
}
},
"node_modules/url-join": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/url-join/-/url-join-4.0.1.tgz",
"integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==",
"dev": true,
"license": "MIT"
},
"node_modules/util-deprecate": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz",
@@ -8447,6 +8961,19 @@
"dev": true,
"license": "MIT"
},
"node_modules/whatwg-encoding": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/whatwg-encoding/-/whatwg-encoding-2.0.0.tgz",
"integrity": "sha512-p41ogyeMUrw3jWclHWTQg1k05DSVXPLcVxRTYsXUk+ZooOCZLcoYgPZ/HL/D/N+uQPOtcp1me1WhBEaX02mhWg==",
"dev": true,
"license": "MIT",
"dependencies": {
"iconv-lite": "0.6.3"
},
"engines": {
"node": ">=12"
}
},
"node_modules/which": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
+1
View File
@@ -52,6 +52,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"http-server": "^14.1.1",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
+4 -2
View File
@@ -2,8 +2,10 @@ import { defineConfig } from '@playwright/test';
export default defineConfig({
webServer: {
command: 'npm run build && npx http-server ../public -p 8181',
port: 8181
command: 'npm run build && http-server ../public -p 8181',
port: 8181,
timeout: 120000,
reuseExistingServer: false
},
testDir: 'e2e'
});
+3 -1
View File
@@ -31,7 +31,8 @@ import type {
DatabaseMessageExtraAudioFile,
DatabaseMessageExtraImageFile,
DatabaseMessageExtraTextFile,
DatabaseMessageExtraPdfFile
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraLegacyContext
} from '$lib/types/database';
import type {
@@ -73,6 +74,7 @@ declare global {
DatabaseMessageExtraImageFile,
DatabaseMessageExtraTextFile,
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraLegacyContext,
SettingsConfigValue,
SettingsFieldConfig,
SettingsConfigType,
@@ -94,6 +94,17 @@
attachmentIndex: index,
textContent: attachment.content
});
} else if (attachment.type === 'context') {
// Legacy format from old webui - treat as text file
items.push({
id: `attachment-${index}`,
name: attachment.name,
type: 'text',
isImage: false,
attachment,
attachmentIndex: index,
textContent: attachment.content
});
} else if (attachment.type === 'audioFile') {
items.push({
id: `attachment-${index}`,
@@ -26,6 +26,7 @@
MimeTypeImage,
MimeTypeText
} from '$lib/enums/files';
import { isIMEComposing } from '$lib/utils/is-ime-composing';
interface Props {
class?: string;
@@ -97,7 +98,7 @@
}
async function handleKeydown(event: KeyboardEvent) {
if (event.key === 'Enter' && !event.shiftKey) {
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
if ((!message.trim() && uploadedFiles.length === 0) || disabled || isLoading) return;
@@ -3,6 +3,8 @@
import { Button } from '$lib/components/ui/button';
import ChatFormActionFileAttachments from './ChatFormActionFileAttachments.svelte';
import ChatFormActionRecord from './ChatFormActionRecord.svelte';
import ChatFormModelSelector from './ChatFormModelSelector.svelte';
import { config } from '$lib/stores/settings.svelte';
import type { FileTypeCategory } from '$lib/enums/files';
interface Props {
@@ -26,32 +28,36 @@
onMicClick,
onStop
}: Props = $props();
let currentConfig = $derived(config());
</script>
<div class="flex items-center justify-between gap-1 {className}">
<ChatFormActionFileAttachments {disabled} {onFileUpload} />
<div class="flex w-full items-center gap-2 {className}">
<ChatFormActionFileAttachments class="mr-auto" {disabled} {onFileUpload} />
<div class="flex gap-2">
{#if isLoading}
<Button
type="button"
onclick={onStop}
class="h-8 w-8 bg-transparent p-0 hover:bg-destructive/20"
>
<span class="sr-only">Stop</span>
<Square class="h-8 w-8 fill-destructive stroke-destructive" />
</Button>
{:else}
<ChatFormActionRecord {disabled} {isLoading} {isRecording} {onMicClick} />
{#if currentConfig.modelSelectorEnabled}
<ChatFormModelSelector class="shrink-0" />
{/if}
<Button
type="submit"
disabled={!canSend || disabled || isLoading}
class="h-8 w-8 rounded-full p-0"
>
<span class="sr-only">Send</span>
<ArrowUp class="h-12 w-12" />
</Button>
{/if}
</div>
{#if isLoading}
<Button
type="button"
onclick={onStop}
class="h-8 w-8 bg-transparent p-0 hover:bg-destructive/20"
>
<span class="sr-only">Stop</span>
<Square class="h-8 w-8 fill-destructive stroke-destructive" />
</Button>
{:else}
<ChatFormActionRecord {disabled} {isLoading} {isRecording} {onMicClick} />
<Button
type="submit"
disabled={!canSend || disabled || isLoading}
class="h-8 w-8 rounded-full p-0"
>
<span class="sr-only">Send</span>
<ArrowUp class="h-12 w-12" />
</Button>
{/if}
</div>
@@ -0,0 +1,358 @@
<script lang="ts">
import { onMount, tick } from 'svelte';
import { ChevronDown, Loader2 } from '@lucide/svelte';
import { cn } from '$lib/components/ui/utils';
import { portalToBody } from '$lib/utils/portal-to-body';
import {
fetchModels,
modelOptions,
modelsError,
modelsLoading,
modelsUpdating,
selectModel,
selectedModelId
} from '$lib/stores/models.svelte';
import type { ModelOption } from '$lib/types/models';
interface Props {
class?: string;
}
let { class: className = '' }: Props = $props();
let options = $derived(modelOptions());
let loading = $derived(modelsLoading());
let updating = $derived(modelsUpdating());
let error = $derived(modelsError());
let activeId = $derived(selectedModelId());
let isMounted = $state(false);
let isOpen = $state(false);
let container: HTMLDivElement | null = null;
let triggerButton = $state<HTMLButtonElement | null>(null);
let menuRef = $state<HTMLDivElement | null>(null);
let menuPosition = $state<{
top: number;
left: number;
width: number;
placement: 'top' | 'bottom';
maxHeight: number;
} | null>(null);
let lockedWidth: number | null = null;
onMount(async () => {
try {
await fetchModels();
} catch (error) {
console.error('Unable to load models:', error);
} finally {
isMounted = true;
}
});
function handlePointerDown(event: PointerEvent) {
if (!container) return;
const target = event.target as Node | null;
if (target && !container.contains(target) && !(menuRef && menuRef.contains(target))) {
closeMenu();
}
}
function handleKeydown(event: KeyboardEvent) {
if (event.key === 'Escape') {
closeMenu();
}
}
function handleResize() {
if (isOpen) {
updateMenuPosition();
}
}
function handleScroll() {
if (isOpen) {
updateMenuPosition();
}
}
async function handleSelect(value: string | undefined) {
if (!value) return;
const option = options.find((item) => item.id === value);
if (!option) {
console.error('Model is no longer available');
return;
}
try {
await selectModel(option.id);
} catch (error) {
console.error('Failed to switch model:', error);
}
}
const VIEWPORT_GUTTER = 8;
const MENU_OFFSET = 6;
const MENU_MAX_WIDTH = 320;
async function openMenu() {
if (loading || updating) return;
isOpen = true;
await tick();
updateMenuPosition();
requestAnimationFrame(() => updateMenuPosition());
}
function toggleOpen() {
if (loading || updating) return;
if (isOpen) {
closeMenu();
} else {
void openMenu();
}
}
function closeMenu() {
if (!isOpen) return;
isOpen = false;
menuPosition = null;
lockedWidth = null;
}
async function handleOptionSelect(optionId: string) {
try {
await handleSelect(optionId);
} finally {
closeMenu();
}
}
$effect(() => {
if (loading || updating) {
closeMenu();
}
});
$effect(() => {
const optionCount = options.length;
if (!isOpen || optionCount <= 0) return;
queueMicrotask(() => updateMenuPosition());
});
function updateMenuPosition() {
if (!isOpen || !triggerButton || !menuRef) return;
const triggerRect = triggerButton.getBoundingClientRect();
const viewportWidth = window.innerWidth;
const viewportHeight = window.innerHeight;
if (viewportWidth === 0 || viewportHeight === 0) return;
const scrollWidth = menuRef.scrollWidth;
const scrollHeight = menuRef.scrollHeight;
const availableWidth = Math.max(0, viewportWidth - VIEWPORT_GUTTER * 2);
const constrainedMaxWidth = Math.min(MENU_MAX_WIDTH, availableWidth || MENU_MAX_WIDTH);
const safeMaxWidth =
constrainedMaxWidth > 0 ? constrainedMaxWidth : Math.min(MENU_MAX_WIDTH, viewportWidth);
const desiredMinWidth = Math.min(160, safeMaxWidth || 160);
let width = lockedWidth;
if (width === null) {
const naturalWidth = Math.min(scrollWidth, safeMaxWidth);
const baseWidth = Math.max(triggerRect.width, naturalWidth, desiredMinWidth);
width = Math.min(baseWidth, safeMaxWidth || baseWidth);
lockedWidth = width;
} else {
width = Math.min(Math.max(width, desiredMinWidth), safeMaxWidth || width);
}
if (width > 0) {
menuRef.style.width = `${width}px`;
}
const availableBelow = Math.max(
0,
viewportHeight - VIEWPORT_GUTTER - triggerRect.bottom - MENU_OFFSET
);
const availableAbove = Math.max(0, triggerRect.top - VIEWPORT_GUTTER - MENU_OFFSET);
const viewportAllowance = Math.max(0, viewportHeight - VIEWPORT_GUTTER * 2);
const fallbackAllowance = Math.max(1, viewportAllowance > 0 ? viewportAllowance : scrollHeight);
function computePlacement(placement: 'top' | 'bottom') {
const available = placement === 'bottom' ? availableBelow : availableAbove;
const allowedHeight =
available > 0 ? Math.min(available, fallbackAllowance) : fallbackAllowance;
const maxHeight = Math.min(scrollHeight, allowedHeight);
const height = Math.max(0, maxHeight);
let top: number;
if (placement === 'bottom') {
const rawTop = triggerRect.bottom + MENU_OFFSET;
const minTop = VIEWPORT_GUTTER;
const maxTop = viewportHeight - VIEWPORT_GUTTER - height;
if (maxTop < minTop) {
top = minTop;
} else {
top = Math.min(Math.max(rawTop, minTop), maxTop);
}
} else {
const rawTop = triggerRect.top - MENU_OFFSET - height;
const minTop = VIEWPORT_GUTTER;
const maxTop = viewportHeight - VIEWPORT_GUTTER - height;
if (maxTop < minTop) {
top = minTop;
} else {
top = Math.max(Math.min(rawTop, maxTop), minTop);
}
}
return { placement, top, height, maxHeight };
}
const belowMetrics = computePlacement('bottom');
const aboveMetrics = computePlacement('top');
let metrics = belowMetrics;
if (scrollHeight > belowMetrics.maxHeight && aboveMetrics.maxHeight > belowMetrics.maxHeight) {
metrics = aboveMetrics;
}
menuRef.style.maxHeight = metrics.maxHeight > 0 ? `${Math.round(metrics.maxHeight)}px` : '';
let left = triggerRect.right - width;
const maxLeft = viewportWidth - VIEWPORT_GUTTER - width;
if (maxLeft < VIEWPORT_GUTTER) {
left = VIEWPORT_GUTTER;
} else {
if (left > maxLeft) {
left = maxLeft;
}
if (left < VIEWPORT_GUTTER) {
left = VIEWPORT_GUTTER;
}
}
menuPosition = {
top: Math.round(metrics.top),
left: Math.round(left),
width: Math.round(width),
placement: metrics.placement,
maxHeight: Math.round(metrics.maxHeight)
};
}
function getDisplayOption(): ModelOption | undefined {
if (activeId) {
return options.find((option) => option.id === activeId);
}
return options[0];
}
</script>
<svelte:window onresize={handleResize} onscroll={handleScroll} />
<svelte:document onpointerdown={handlePointerDown} onkeydown={handleKeydown} />
<div
class={cn('relative z-10 flex max-w-[200px] min-w-[120px] flex-col items-end gap-1', className)}
bind:this={container}
>
{#if loading && options.length === 0 && !isMounted}
<div class="flex items-center gap-2 text-xs text-muted-foreground">
<Loader2 class="h-4 w-4 animate-spin" />
Loading models…
</div>
{:else if options.length === 0}
<p class="text-xs text-muted-foreground">No models available.</p>
{:else}
{@const selectedOption = getDisplayOption()}
<div class="relative w-full">
<button
type="button"
class={cn(
'flex w-full items-center justify-end gap-2 rounded-md px-2 py-1 text-sm text-muted-foreground transition hover:text-foreground focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60',
isOpen ? 'text-foreground' : ''
)}
aria-haspopup="listbox"
aria-expanded={isOpen}
onclick={toggleOpen}
bind:this={triggerButton}
disabled={loading || updating}
>
<span class="max-w-[160px] truncate text-right font-medium">
{selectedOption?.name || 'Select model'}
</span>
{#if updating}
<Loader2 class="h-3.5 w-3.5 animate-spin text-muted-foreground" />
{:else}
<ChevronDown
class={cn(
'h-4 w-4 text-muted-foreground transition-transform',
isOpen ? 'rotate-180 text-foreground' : ''
)}
/>
{/if}
</button>
{#if isOpen}
<div
bind:this={menuRef}
use:portalToBody
class={cn(
'fixed z-[1000] overflow-hidden rounded-md border bg-popover shadow-lg transition-opacity',
menuPosition ? 'opacity-100' : 'pointer-events-none opacity-0'
)}
role="listbox"
style:top={menuPosition ? `${menuPosition.top}px` : undefined}
style:left={menuPosition ? `${menuPosition.left}px` : undefined}
style:width={menuPosition ? `${menuPosition.width}px` : undefined}
data-placement={menuPosition?.placement ?? 'bottom'}
>
<div
class="overflow-y-auto py-1"
style:max-height={menuPosition && menuPosition.maxHeight > 0
? `${menuPosition.maxHeight}px`
: undefined}
>
{#each options as option (option.id)}
<button
type="button"
class={cn(
'flex w-full flex-col items-start gap-0.5 px-3 py-2 text-left text-sm transition hover:bg-muted focus:bg-muted focus:outline-none',
option.id === selectedOption?.id ? 'bg-accent text-accent-foreground' : ''
)}
role="option"
aria-selected={option.id === selectedOption?.id}
onclick={() => handleOptionSelect(option.id)}
>
<span class="block w-full truncate font-medium" title={option.name}>
{option.name}
</span>
{#if option.description}
<span class="text-xs text-muted-foreground">{option.description}</span>
{/if}
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
{#if error}
<p class="text-xs text-destructive">{error}</p>
{/if}
</div>
@@ -1,6 +1,7 @@
<script lang="ts">
import { getDeletionInfo } from '$lib/stores/chat.svelte';
import { copyToClipboard } from '$lib/utils/copy';
import { isIMEComposing } from '$lib/utils/is-ime-composing';
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
import ChatMessageUser from './ChatMessageUser.svelte';
@@ -93,7 +94,9 @@
}
function handleEditKeydown(event: KeyboardEvent) {
if (event.key === 'Enter' && !event.shiftKey) {
// Check for IME composition using isComposing property and keyCode 229 (specifically for IME composition on Safari)
// This prevents saving edit when confirming IME word selection (e.g., Japanese/Chinese input)
if (event.key === 'Enter' && !event.shiftKey && !isIMEComposing(event)) {
event.preventDefault();
handleSaveEdit();
} else if (event.key === 'Escape') {
@@ -10,6 +10,7 @@
import ChatMessageActions from './ChatMessageActions.svelte';
import Label from '$lib/components/ui/label/label.svelte';
import { config } from '$lib/stores/settings.svelte';
import { modelName as serverModelName } from '$lib/stores/server.svelte';
import { copyToClipboard } from '$lib/utils/copy';
interface Props {
@@ -70,6 +71,23 @@
}: Props = $props();
const processingState = useProcessingState();
let currentConfig = $derived(config());
let serverModel = $derived(serverModelName());
let displayedModel = $derived((): string | null => {
if (!currentConfig.showModelInfo) return null;
if (currentConfig.modelSelectorEnabled) {
return message.model ?? null;
}
return serverModel;
});
function handleCopyModel() {
const model = displayedModel();
void copyToClipboard(model ?? '');
}
</script>
<div
@@ -142,7 +160,7 @@
</div>
{/if}
{#if config().showModelInfo && message.model}
{#if displayedModel()}
<span class="mt-6 mb-4 inline-flex items-center gap-1 text-xs text-muted-foreground">
<Package class="h-3.5 w-3.5" />
@@ -150,9 +168,9 @@
<button
class="inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
onclick={() => copyToClipboard(message.model)}
onclick={handleCopyModel}
>
{message.model}
{displayedModel()}
<Copy class="ml-1 h-3 w-3 " />
</button>
@@ -7,18 +7,19 @@
const processingState = useProcessingState();
let isCurrentConversationLoading = $derived(isLoading());
let processingDetails = $derived(processingState.getProcessingDetails());
let showSlotsInfo = $derived(isCurrentConversationLoading || config().keepStatsVisible);
let showSlotsInfo = $derived(isLoading() || config().keepStatsVisible);
// Track loading state reactively by checking if conversation ID is in loading conversations array
$effect(() => {
const keepStatsVisible = config().keepStatsVisible;
if (keepStatsVisible || isLoading()) {
if (keepStatsVisible || isCurrentConversationLoading) {
processingState.startMonitoring();
}
if (!isLoading() && !keepStatsVisible) {
if (!isCurrentConversationLoading && !keepStatsVisible) {
setTimeout(() => {
if (!config().keepStatsVisible) {
processingState.stopMonitoring();
@@ -27,18 +28,20 @@
}
});
// Update processing state from stored timings
$effect(() => {
activeConversation();
const conversation = activeConversation();
const messages = activeMessages() as DatabaseMessage[];
const keepStatsVisible = config().keepStatsVisible;
if (keepStatsVisible) {
if (keepStatsVisible && conversation) {
if (messages.length === 0) {
slotsService.clearState();
slotsService.clearConversationState(conversation.id);
return;
}
// Search backwards through messages to find most recent assistant message with timing data
// Using reverse iteration for performance - avoids array copy and stops at first match
let foundTimingData = false;
for (let i = messages.length - 1; i >= 0; i--) {
@@ -47,15 +50,18 @@
foundTimingData = true;
slotsService
.updateFromTimingData({
prompt_n: message.timings.prompt_n || 0,
predicted_n: message.timings.predicted_n || 0,
predicted_per_second:
message.timings.predicted_n && message.timings.predicted_ms
? (message.timings.predicted_n / message.timings.predicted_ms) * 1000
: 0,
cache_n: message.timings.cache_n || 0
})
.updateFromTimingData(
{
prompt_n: message.timings.prompt_n || 0,
predicted_n: message.timings.predicted_n || 0,
predicted_per_second:
message.timings.predicted_n && message.timings.predicted_ms
? (message.timings.predicted_n / message.timings.predicted_ms) * 1000
: 0,
cache_n: message.timings.cache_n || 0
},
conversation.id
)
.catch((error) => {
console.warn('Failed to update processing state from stored timings:', error);
});
@@ -64,7 +70,7 @@
}
if (!foundTimingData) {
slotsService.clearState();
slotsService.clearConversationState(conversation.id);
}
}
});
@@ -83,6 +83,8 @@
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
let isCurrentConversationLoading = $derived(isLoading());
async function handleDeleteConfirm() {
const conversation = activeConversation();
if (conversation) {
@@ -254,7 +256,7 @@
});
$effect(() => {
if (isLoading() && autoScrollEnabled) {
if (isCurrentConversationLoading && autoScrollEnabled) {
scrollInterval = setInterval(scrollChatToBottom, AUTO_SCROLL_INTERVAL);
} else if (scrollInterval) {
clearInterval(scrollInterval);
@@ -305,7 +307,7 @@
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4">
<ChatForm
isLoading={isLoading()}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
@@ -348,7 +350,7 @@
<div in:fly={{ y: 10, duration: 250, delay: 300 }}>
<ChatForm
isLoading={isLoading()}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
@@ -4,14 +4,16 @@
Funnel,
AlertTriangle,
Brain,
Cog,
Code,
Monitor,
Sun,
Moon,
ChevronLeft,
ChevronRight
ChevronRight,
Database
} from '@lucide/svelte';
import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app';
import ImportExportTab from './ImportExportTab.svelte';
import * as Dialog from '$lib/components/ui/dialog';
import { ScrollArea } from '$lib/components/ui/scroll-area';
import { config, updateMultipleConfig } from '$lib/stores/settings.svelte';
@@ -88,9 +90,59 @@
]
},
{
title: 'Samplers',
title: 'Sampling',
icon: Funnel,
fields: [
{
key: 'temperature',
label: 'Temperature',
type: 'input'
},
{
key: 'dynatemp_range',
label: 'Dynamic temperature range',
type: 'input'
},
{
key: 'dynatemp_exponent',
label: 'Dynamic temperature exponent',
type: 'input'
},
{
key: 'top_k',
label: 'Top K',
type: 'input'
},
{
key: 'top_p',
label: 'Top P',
type: 'input'
},
{
key: 'min_p',
label: 'Min P',
type: 'input'
},
{
key: 'xtc_probability',
label: 'XTC probability',
type: 'input'
},
{
key: 'xtc_threshold',
label: 'XTC threshold',
type: 'input'
},
{
key: 'typ_p',
label: 'Typical P',
type: 'input'
},
{
key: 'max_tokens',
label: 'Max tokens',
type: 'input'
},
{
key: 'samplers',
label: 'Samplers',
@@ -152,68 +204,27 @@
key: 'showThoughtInProgress',
label: 'Show thought in progress',
type: 'checkbox'
},
{
key: 'disableReasoningFormat',
label:
'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
type: 'checkbox'
}
]
},
{
title: 'Advanced',
icon: Cog,
title: 'Import/Export',
icon: Database,
fields: []
},
{
title: 'Developer',
icon: Code,
fields: [
{
key: 'temperature',
label: 'Temperature',
type: 'input'
key: 'modelSelectorEnabled',
label: 'Enable model selector',
type: 'checkbox'
},
{
key: 'dynatemp_range',
label: 'Dynamic temperature range',
type: 'input'
},
{
key: 'dynatemp_exponent',
label: 'Dynamic temperature exponent',
type: 'input'
},
{
key: 'top_k',
label: 'Top K',
type: 'input'
},
{
key: 'top_p',
label: 'Top P',
type: 'input'
},
{
key: 'min_p',
label: 'Min P',
type: 'input'
},
{
key: 'xtc_probability',
label: 'XTC probability',
type: 'input'
},
{
key: 'xtc_threshold',
label: 'XTC threshold',
type: 'input'
},
{
key: 'typ_p',
label: 'Typical P',
type: 'input'
},
{
key: 'max_tokens',
label: 'Max tokens',
type: 'input'
key: 'disableReasoningFormat',
label: 'Show raw LLM output',
type: 'checkbox'
},
{
key: 'custom',
@@ -456,21 +467,25 @@
<ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
<div class="space-y-6 p-4 md:p-6">
<div>
<div class="grid">
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
<currentSection.icon class="h-5 w-5" />
<h3 class="text-lg font-semibold">{currentSection.title}</h3>
</div>
<div class="space-y-6">
<ChatSettingsFields
fields={currentSection.fields}
{localConfig}
onConfigChange={handleConfigChange}
onThemeChange={handleThemeChange}
/>
</div>
{#if currentSection.title === 'Import/Export'}
<ImportExportTab />
{:else}
<div class="space-y-6">
<ChatSettingsFields
fields={currentSection.fields}
{localConfig}
onConfigChange={handleConfigChange}
onThemeChange={handleThemeChange}
/>
</div>
{/if}
</div>
<div class="mt-8 border-t pt-6">
@@ -0,0 +1,249 @@
<script lang="ts">
import { Search, X } from '@lucide/svelte';
import * as Dialog from '$lib/components/ui/dialog';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { Checkbox } from '$lib/components/ui/checkbox';
import { ScrollArea } from '$lib/components/ui/scroll-area';
import { SvelteSet } from 'svelte/reactivity';
interface Props {
conversations: DatabaseConversation[];
messageCountMap?: Map<string, number>;
mode: 'export' | 'import';
onCancel: () => void;
onConfirm: (selectedConversations: DatabaseConversation[]) => void;
open?: boolean;
}
let {
conversations,
messageCountMap = new Map(),
mode,
onCancel,
onConfirm,
open = $bindable(false)
}: Props = $props();
let searchQuery = $state('');
let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
let lastClickedId = $state<string | null>(null);
let filteredConversations = $derived(
conversations.filter((conv) => {
const name = conv.name || 'Untitled conversation';
return name.toLowerCase().includes(searchQuery.toLowerCase());
})
);
let allSelected = $derived(
filteredConversations.length > 0 &&
filteredConversations.every((conv) => selectedIds.has(conv.id))
);
let someSelected = $derived(
filteredConversations.some((conv) => selectedIds.has(conv.id)) && !allSelected
);
function toggleConversation(id: string, shiftKey: boolean = false) {
const newSet = new SvelteSet(selectedIds);
if (shiftKey && lastClickedId !== null) {
const lastIndex = filteredConversations.findIndex((c) => c.id === lastClickedId);
const currentIndex = filteredConversations.findIndex((c) => c.id === id);
if (lastIndex !== -1 && currentIndex !== -1) {
const start = Math.min(lastIndex, currentIndex);
const end = Math.max(lastIndex, currentIndex);
const shouldSelect = !newSet.has(id);
for (let i = start; i <= end; i++) {
if (shouldSelect) {
newSet.add(filteredConversations[i].id);
} else {
newSet.delete(filteredConversations[i].id);
}
}
selectedIds = newSet;
return;
}
}
if (newSet.has(id)) {
newSet.delete(id);
} else {
newSet.add(id);
}
selectedIds = newSet;
lastClickedId = id;
}
function toggleAll() {
if (allSelected) {
const newSet = new SvelteSet(selectedIds);
filteredConversations.forEach((conv) => newSet.delete(conv.id));
selectedIds = newSet;
} else {
const newSet = new SvelteSet(selectedIds);
filteredConversations.forEach((conv) => newSet.add(conv.id));
selectedIds = newSet;
}
}
function handleConfirm() {
const selected = conversations.filter((conv) => selectedIds.has(conv.id));
onConfirm(selected);
}
function handleCancel() {
selectedIds = new SvelteSet(conversations.map((c) => c.id));
searchQuery = '';
lastClickedId = null;
onCancel();
}
let previousOpen = $state(false);
$effect(() => {
if (open && !previousOpen) {
selectedIds = new SvelteSet(conversations.map((c) => c.id));
searchQuery = '';
lastClickedId = null;
} else if (!open && previousOpen) {
onCancel();
}
previousOpen = open;
});
</script>
<Dialog.Root bind:open>
<Dialog.Portal>
<Dialog.Overlay class="z-[1000000]" />
<Dialog.Content class="z-[1000001] max-w-2xl">
<Dialog.Header>
<Dialog.Title>
Select Conversations to {mode === 'export' ? 'Export' : 'Import'}
</Dialog.Title>
<Dialog.Description>
{#if mode === 'export'}
Choose which conversations you want to export. Selected conversations will be downloaded
as a JSON file.
{:else}
Choose which conversations you want to import. Selected conversations will be merged
with your existing conversations.
{/if}
</Dialog.Description>
</Dialog.Header>
<div class="space-y-4">
<div class="relative">
<Search class="absolute top-1/2 left-3 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input bind:value={searchQuery} placeholder="Search conversations..." class="pr-9 pl-9" />
{#if searchQuery}
<button
class="absolute top-1/2 right-3 -translate-y-1/2 text-muted-foreground hover:text-foreground"
onclick={() => (searchQuery = '')}
type="button"
>
<X class="h-4 w-4" />
</button>
{/if}
</div>
<div class="flex items-center justify-between text-sm text-muted-foreground">
<span>
{selectedIds.size} of {conversations.length} selected
{#if searchQuery}
({filteredConversations.length} shown)
{/if}
</span>
</div>
<div class="overflow-hidden rounded-md border">
<ScrollArea class="h-[400px]">
<table class="w-full">
<thead class="sticky top-0 z-10 bg-muted">
<tr class="border-b">
<th class="w-12 p-3 text-left">
<Checkbox
checked={allSelected}
indeterminate={someSelected}
onCheckedChange={toggleAll}
/>
</th>
<th class="p-3 text-left text-sm font-medium">Conversation Name</th>
<th class="w-32 p-3 text-left text-sm font-medium">Messages</th>
</tr>
</thead>
<tbody>
{#if filteredConversations.length === 0}
<tr>
<td colspan="3" class="p-8 text-center text-sm text-muted-foreground">
{#if searchQuery}
No conversations found matching "{searchQuery}"
{:else}
No conversations available
{/if}
</td>
</tr>
{:else}
{#each filteredConversations as conv (conv.id)}
<tr
class="cursor-pointer border-b transition-colors hover:bg-muted/50"
onclick={(e) => toggleConversation(conv.id, e.shiftKey)}
>
<td class="p-3">
<Checkbox
checked={selectedIds.has(conv.id)}
onclick={(e) => {
e.preventDefault();
e.stopPropagation();
toggleConversation(conv.id, e.shiftKey);
}}
/>
</td>
<td class="p-3 text-sm">
<div
class="max-w-[17rem] truncate"
title={conv.name || 'Untitled conversation'}
>
{conv.name || 'Untitled conversation'}
</div>
</td>
<td class="p-3 text-sm text-muted-foreground">
{messageCountMap.get(conv.id) ?? 0}
</td>
</tr>
{/each}
{/if}
</tbody>
</table>
</ScrollArea>
</div>
</div>
<Dialog.Footer>
<Button variant="outline" onclick={handleCancel}>Cancel</Button>
<Button onclick={handleConfirm} disabled={selectedIds.size === 0}>
{mode === 'export' ? 'Export' : 'Import'} ({selectedIds.size})
</Button>
</Dialog.Footer>
</Dialog.Content>
</Dialog.Portal>
</Dialog.Root>
@@ -0,0 +1,255 @@
<script lang="ts">
import { Download, Upload } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import ConversationSelectionDialog from './ConversationSelectionDialog.svelte';
import { DatabaseStore } from '$lib/stores/database';
import type { ExportedConversations } from '$lib/types/database';
import { createMessageCountMap } from '$lib/utils/conversation-utils';
import { chatStore } from '$lib/stores/chat.svelte';
let exportedConversations = $state<DatabaseConversation[]>([]);
let importedConversations = $state<DatabaseConversation[]>([]);
let showExportSummary = $state(false);
let showImportSummary = $state(false);
let showExportDialog = $state(false);
let showImportDialog = $state(false);
let availableConversations = $state<DatabaseConversation[]>([]);
let messageCountMap = $state<Map<string, number>>(new Map());
let fullImportData = $state<Array<{ conv: DatabaseConversation; messages: DatabaseMessage[] }>>(
[]
);
async function handleExportClick() {
try {
const allConversations = await DatabaseStore.getAllConversations();
if (allConversations.length === 0) {
alert('No conversations to export');
return;
}
const conversationsWithMessages = await Promise.all(
allConversations.map(async (conv) => {
const messages = await DatabaseStore.getConversationMessages(conv.id);
return { conv, messages };
})
);
messageCountMap = createMessageCountMap(conversationsWithMessages);
availableConversations = allConversations;
showExportDialog = true;
} catch (err) {
console.error('Failed to load conversations:', err);
alert('Failed to load conversations');
}
}
async function handleExportConfirm(selectedConversations: DatabaseConversation[]) {
try {
const allData: ExportedConversations = await Promise.all(
selectedConversations.map(async (conv) => {
const messages = await DatabaseStore.getConversationMessages(conv.id);
return { conv: $state.snapshot(conv), messages: $state.snapshot(messages) };
})
);
const blob = new Blob([JSON.stringify(allData, null, 2)], {
type: 'application/json'
});
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `conversations_${new Date().toISOString().split('T')[0]}.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
exportedConversations = selectedConversations;
showExportSummary = true;
showImportSummary = false;
showExportDialog = false;
} catch (err) {
console.error('Export failed:', err);
alert('Failed to export conversations');
}
}
async function handleImportClick() {
try {
const input = document.createElement('input');
input.type = 'file';
input.accept = '.json';
input.onchange = async (e) => {
const file = (e.target as HTMLInputElement)?.files?.[0];
if (!file) return;
try {
const text = await file.text();
const parsedData = JSON.parse(text);
let importedData: ExportedConversations;
if (Array.isArray(parsedData)) {
importedData = parsedData;
} else if (
parsedData &&
typeof parsedData === 'object' &&
'conv' in parsedData &&
'messages' in parsedData
) {
// Single conversation object
importedData = [parsedData];
} else {
throw new Error(
'Invalid file format: expected array of conversations or single conversation object'
);
}
fullImportData = importedData;
availableConversations = importedData.map(
(item: { conv: DatabaseConversation; messages: DatabaseMessage[] }) => item.conv
);
messageCountMap = createMessageCountMap(importedData);
showImportDialog = true;
} catch (err: unknown) {
const message = err instanceof Error ? err.message : 'Unknown error';
console.error('Failed to parse file:', err);
alert(`Failed to parse file: ${message}`);
}
};
input.click();
} catch (err) {
console.error('Import failed:', err);
alert('Failed to import conversations');
}
}
async function handleImportConfirm(selectedConversations: DatabaseConversation[]) {
try {
const selectedIds = new Set(selectedConversations.map((c) => c.id));
const selectedData = $state
.snapshot(fullImportData)
.filter((item) => selectedIds.has(item.conv.id));
await DatabaseStore.importConversations(selectedData);
await chatStore.loadConversations();
importedConversations = selectedConversations;
showImportSummary = true;
showExportSummary = false;
showImportDialog = false;
} catch (err) {
console.error('Import failed:', err);
alert('Failed to import conversations. Please check the file format.');
}
}
</script>
<div class="space-y-6">
<div class="space-y-4">
<div class="grid">
<h4 class="mb-2 text-sm font-medium">Export Conversations</h4>
<p class="mb-4 text-sm text-muted-foreground">
Download all your conversations as a JSON file. This includes all messages, attachments, and
conversation history.
</p>
<Button
class="w-full justify-start justify-self-start md:w-auto"
onclick={handleExportClick}
variant="outline"
>
<Download class="mr-2 h-4 w-4" />
Export conversations
</Button>
{#if showExportSummary && exportedConversations.length > 0}
<div class="mt-4 grid overflow-x-auto rounded-lg border border-border/50 bg-muted/30 p-4">
<h5 class="mb-2 text-sm font-medium">
Exported {exportedConversations.length} conversation{exportedConversations.length === 1
? ''
: 's'}
</h5>
<ul class="space-y-1 text-sm text-muted-foreground">
{#each exportedConversations.slice(0, 10) as conv (conv.id)}
<li class="truncate">{conv.name || 'Untitled conversation'}</li>
{/each}
{#if exportedConversations.length > 10}
<li class="italic">
... and {exportedConversations.length - 10} more
</li>
{/if}
</ul>
</div>
{/if}
</div>
<div class="grid border-t border-border/30 pt-4">
<h4 class="mb-2 text-sm font-medium">Import Conversations</h4>
<p class="mb-4 text-sm text-muted-foreground">
Import one or more conversations from a previously exported JSON file. This will merge with
your existing conversations.
</p>
<Button
class="w-full justify-start justify-self-start md:w-auto"
onclick={handleImportClick}
variant="outline"
>
<Upload class="mr-2 h-4 w-4" />
Import conversations
</Button>
{#if showImportSummary && importedConversations.length > 0}
<div class="mt-4 grid overflow-x-auto rounded-lg border border-border/50 bg-muted/30 p-4">
<h5 class="mb-2 text-sm font-medium">
Imported {importedConversations.length} conversation{importedConversations.length === 1
? ''
: 's'}
</h5>
<ul class="space-y-1 text-sm text-muted-foreground">
{#each importedConversations.slice(0, 10) as conv (conv.id)}
<li class="truncate">{conv.name || 'Untitled conversation'}</li>
{/each}
{#if importedConversations.length > 10}
<li class="italic">
... and {importedConversations.length - 10} more
</li>
{/if}
</ul>
</div>
{/if}
</div>
</div>
</div>
<ConversationSelectionDialog
conversations={availableConversations}
{messageCountMap}
mode="export"
bind:open={showExportDialog}
onCancel={() => (showExportDialog = false)}
onConfirm={handleExportConfirm}
/>
<ConversationSelectionDialog
conversations={availableConversations}
{messageCountMap}
mode="import"
bind:open={showImportDialog}
onCancel={() => (showImportDialog = false)}
onConfirm={handleImportConfirm}
/>
@@ -1,9 +1,8 @@
<script lang="ts">
import { Search, SquarePen, X, Download, Upload } from '@lucide/svelte';
import { Search, SquarePen, X } from '@lucide/svelte';
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { exportAllConversations, importConversations } from '$lib/stores/chat.svelte';
interface Props {
handleMobileSidebarItemClick: () => void;
@@ -78,34 +77,5 @@
<KeyboardShortcutInfo keys={['cmd', 'k']} />
</Button>
<Button
class="w-full justify-start text-sm"
onclick={() => {
importConversations().catch((err) => {
console.error('Import failed:', err);
// Optional: show toast or dialog
});
}}
variant="ghost"
>
<div class="flex items-center gap-2">
<Upload class="h-4 w-4" />
Import conversations
</div>
</Button>
<Button
class="w-full justify-start text-sm"
onclick={() => {
exportAllConversations();
}}
variant="ghost"
>
<div class="flex items-center gap-2">
<Download class="h-4 w-4" />
Export all conversations
</div>
</Button>
{/if}
</div>
@@ -1,7 +1,7 @@
<script lang="ts">
import { Trash2, Pencil, MoreHorizontal, Download } from '@lucide/svelte';
import { Trash2, Pencil, MoreHorizontal, Download, Loader2 } from '@lucide/svelte';
import { ActionDropdown } from '$lib/components/app';
import { downloadConversation } from '$lib/stores/chat.svelte';
import { downloadConversation, getAllLoadingConversations } from '$lib/stores/chat.svelte';
import { onMount } from 'svelte';
interface Props {
@@ -25,6 +25,8 @@
let renderActionsDropdown = $state(false);
let dropdownOpen = $state(false);
let isLoading = $derived(getAllLoadingConversations().includes(conversation.id));
function handleEdit(event: Event) {
event.stopPropagation();
onEdit?.(conversation.id);
@@ -83,11 +85,16 @@
onmouseover={handleMouseOver}
onmouseleave={handleMouseLeave}
>
<!-- svelte-ignore a11y_click_events_have_key_events -->
<!-- svelte-ignore a11y_no_static_element_interactions -->
<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
{conversation.name}
</span>
<div class="flex min-w-0 flex-1 items-center gap-2">
{#if isLoading}
<Loader2 class="h-3.5 w-3.5 shrink-0 animate-spin text-muted-foreground" />
{/if}
<!-- svelte-ignore a11y_click_events_have_key_events -->
<!-- svelte-ignore a11y_no_static_element_interactions -->
<span class="truncate text-sm font-medium" onclick={handleMobileSidebarItemClick}>
{conversation.name}
</span>
</div>
{#if renderActionsDropdown}
<div class="actions flex items-center">
@@ -8,6 +8,7 @@ export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.sv
export { default as ChatFormActions } from './chat/ChatForm/ChatFormActions.svelte';
export { default as ChatFormActionFileAttachments } from './chat/ChatForm/ChatFormActionFileAttachments.svelte';
export { default as ChatFormActionRecord } from './chat/ChatForm/ChatFormActionRecord.svelte';
export { default as ChatFormModelSelector } from './chat/ChatForm/ChatFormModelSelector.svelte';
export { default as ChatFormHelperText } from './chat/ChatForm/ChatFormHelperText.svelte';
export { default as ChatFormFileInputInvisible } from './chat/ChatForm/ChatFormFileInputInvisible.svelte';
@@ -25,12 +26,13 @@ export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte';
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
export { default as ImportExportTab } from './chat/ChatSettings/ImportExportTab.svelte';
export { default as ConversationSelectionDialog } from './chat/ChatSettings/ConversationSelectionDialog.svelte';
export { default as ParameterSourceIndicator } from './chat/ChatSettings/ParameterSourceIndicator.svelte';
export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
@@ -154,9 +154,20 @@
return mutated ? tempDiv.innerHTML : html;
}
function normalizeMathDelimiters(text: string): string {
return text
.replace(/(^|[^\\])\\\[((?:\\.|[\s\S])*?)\\\]/g, (_, prefix: string, content: string) => {
return `${prefix}$$${content}$$`;
})
.replace(/(^|[^\\])\\\(((?:\\.|[\s\S])*?)\\\)/g, (_, prefix: string, content: string) => {
return `${prefix}$${content}$`;
});
}
async function processMarkdown(text: string): Promise<string> {
try {
const result = await processor().process(text);
const normalized = normalizeMathDelimiters(text);
const result = await processor().process(normalized);
const html = String(result);
const enhancedLinks = enhanceLinks(html);
@@ -8,22 +8,33 @@
class: className,
children,
size = 'default',
variant = 'default',
...restProps
}: WithoutChild<SelectPrimitive.TriggerProps> & {
size?: 'sm' | 'default';
variant?: 'default' | 'plain';
} = $props();
const baseClasses = $derived(
variant === 'plain'
? "group inline-flex w-full items-center justify-end gap-2 whitespace-nowrap px-0 py-0 text-sm font-medium text-muted-foreground transition-colors focus-visible:outline-none focus-visible:ring-0 focus-visible:ring-offset-0 disabled:cursor-not-allowed disabled:opacity-50 data-[placeholder]:text-muted-foreground data-[size=default]:h-9 data-[size=sm]:h-8 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-3 [&_svg:not([class*='text-'])]:text-muted-foreground"
: "flex w-fit items-center justify-between gap-2 rounded-md border border-input bg-transparent px-3 py-2 text-sm whitespace-nowrap shadow-xs transition-[color,box-shadow] outline-none select-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 aria-invalid:border-destructive aria-invalid:ring-destructive/20 data-[placeholder]:text-muted-foreground data-[size=default]:h-9 data-[size=sm]:h-8 *:data-[slot=select-value]:line-clamp-1 *:data-[slot=select-value]:flex *:data-[slot=select-value]:items-center *:data-[slot=select-value]:gap-2 dark:bg-input/30 dark:hover:bg-input/50 dark:aria-invalid:ring-destructive/40 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4 [&_svg:not([class*='text-'])]:text-muted-foreground"
);
const chevronClasses = $derived(
variant === 'plain'
? 'size-3 opacity-60 transition-transform group-data-[state=open]:-rotate-180'
: 'size-4 opacity-50'
);
</script>
<SelectPrimitive.Trigger
bind:ref
data-slot="select-trigger"
data-size={size}
class={cn(
"flex w-fit items-center justify-between gap-2 rounded-md border border-input bg-transparent px-3 py-2 text-sm whitespace-nowrap shadow-xs transition-[color,box-shadow] outline-none select-none focus-visible:border-ring focus-visible:ring-[3px] focus-visible:ring-ring/50 disabled:cursor-not-allowed disabled:opacity-50 aria-invalid:border-destructive aria-invalid:ring-destructive/20 data-[placeholder]:text-muted-foreground data-[size=default]:h-9 data-[size=sm]:h-8 *:data-[slot=select-value]:line-clamp-1 *:data-[slot=select-value]:flex *:data-[slot=select-value]:items-center *:data-[slot=select-value]:gap-2 dark:bg-input/30 dark:hover:bg-input/50 dark:aria-invalid:ring-destructive/40 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4 [&_svg:not([class*='text-'])]:text-muted-foreground",
className
)}
class={cn(baseClasses, className)}
{...restProps}
>
{@render children?.()}
<ChevronDownIcon class="size-4 opacity-50" />
<ChevronDownIcon class={chevronClasses} />
</SelectPrimitive.Trigger>
@@ -1 +1,2 @@
export const SERVER_PROPS_LOCALSTORAGE_KEY = 'LlamaCppWebui.serverProps';
export const SELECTED_MODEL_LOCALSTORAGE_KEY = 'LlamaCppWebui.selectedModel';
@@ -13,6 +13,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
pdfAsImage: false,
showModelInfo: false,
renderUserContentAsMarkdown: false,
modelSelectorEnabled: false,
// make sure these default values are in sync with `common.h`
samplers: 'top_k;typ_p;top_p;min_p;temperature',
temperature: 0.8,
@@ -86,6 +87,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
showModelInfo: 'Display the model name used to generate each message below the message content.',
renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.',
modelSelectorEnabled:
'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.',
pyInterpreterEnabled:
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
};
+141 -43
View File
@@ -1,4 +1,5 @@
import { config } from '$lib/stores/settings.svelte';
import { selectedModelName } from '$lib/stores/models.svelte';
import { slotsService } from './slots';
/**
* ChatService - Low-level API communication layer for llama.cpp server interactions
@@ -29,7 +30,7 @@ import { slotsService } from './slots';
* - Request lifecycle management (abort, cleanup)
*/
export class ChatService {
private abortController: AbortController | null = null;
private abortControllers: Map<string, AbortController> = new Map();
/**
* Sends a chat completion request to the llama.cpp server.
@@ -43,13 +44,16 @@ export class ChatService {
*/
async sendMessage(
messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[],
options: SettingsChatServiceOptions = {}
options: SettingsChatServiceOptions = {},
conversationId?: string
): Promise<string | void> {
const {
stream,
onChunk,
onComplete,
onError,
onReasoningChunk,
onModel,
// Generation parameters
temperature,
max_tokens,
@@ -79,25 +83,25 @@ export class ChatService {
const currentConfig = config();
// Cancel any ongoing request and create a new abort controller
this.abort();
this.abortController = new AbortController();
const requestId = conversationId || 'default';
if (this.abortControllers.has(requestId)) {
this.abortControllers.get(requestId)?.abort();
}
const abortController = new AbortController();
this.abortControllers.set(requestId, abortController);
// Convert database messages with attachments to API format if needed
const normalizedMessages: ApiChatMessageData[] = messages
.map((msg) => {
// Check if this is a DatabaseMessage by checking for DatabaseMessage-specific fields
if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
// This is a DatabaseMessage, convert it
const dbMsg = msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] };
return ChatService.convertMessageToChatServiceData(dbMsg);
} else {
// This is already an ApiChatMessageData object
return msg as ApiChatMessageData;
}
})
.filter((msg) => {
// Filter out empty system messages
if (msg.role === 'system') {
const content = typeof msg.content === 'string' ? msg.content : '';
@@ -107,7 +111,6 @@ export class ChatService {
return true;
});
// Build base request body with system message injection
const processedMessages = this.injectSystemMessage(normalizedMessages);
const requestBody: ApiChatCompletionRequest = {
@@ -118,6 +121,13 @@ export class ChatService {
stream
};
const modelSelectorEnabled = Boolean(currentConfig.modelSelectorEnabled);
const activeModel = modelSelectorEnabled ? selectedModelName() : null;
if (modelSelectorEnabled && activeModel) {
requestBody.model = activeModel;
}
requestBody.reasoning_format = currentConfig.disableReasoningFormat ? 'none' : 'auto';
if (temperature !== undefined) requestBody.temperature = temperature;
@@ -172,11 +182,10 @@ export class ChatService {
...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {})
},
body: JSON.stringify(requestBody),
signal: this.abortController.signal
signal: abortController.signal
});
if (!response.ok) {
// Use the new parseErrorResponse method to handle structured errors
const error = await this.parseErrorResponse(response);
if (onError) {
onError(error);
@@ -185,15 +194,19 @@ export class ChatService {
}
if (stream) {
return this.handleStreamResponse(
await this.handleStreamResponse(
response,
onChunk,
onComplete,
onError,
options.onReasoningChunk
onReasoningChunk,
onModel,
conversationId,
abortController.signal
);
return;
} else {
return this.handleNonStreamResponse(response, onComplete, onError);
return this.handleNonStreamResponse(response, onComplete, onError, onModel);
}
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
@@ -227,18 +240,19 @@ export class ChatService {
onError(userFriendlyError);
}
throw userFriendlyError;
} finally {
this.abortControllers.delete(requestId);
}
}
/**
* Handles streaming response from the chat completion API.
* Processes server-sent events and extracts content chunks from the stream.
*
* @param response - The fetch Response object containing the streaming data
* Handles streaming response from the chat completion API
* @param response - The Response object from the fetch request
* @param onChunk - Optional callback invoked for each content chunk received
* @param onComplete - Optional callback invoked when the stream is complete with full response
* @param onError - Optional callback invoked if an error occurs during streaming
* @param onReasoningChunk - Optional callback invoked for each reasoning content chunk
* @param conversationId - Optional conversation ID for per-conversation state tracking
* @returns {Promise<void>} Promise that resolves when streaming is complete
* @throws {Error} if the stream cannot be read or parsed
*/
@@ -251,7 +265,10 @@ export class ChatService {
timings?: ChatMessageTimings
) => void,
onError?: (error: Error) => void,
onReasoningChunk?: (chunk: string) => void
onReasoningChunk?: (chunk: string) => void,
onModel?: (model: string) => void,
conversationId?: string,
abortSignal?: AbortSignal
): Promise<void> {
const reader = response.body?.getReader();
@@ -265,18 +282,25 @@ export class ChatService {
let hasReceivedData = false;
let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
let modelEmitted = false;
try {
let chunk = '';
while (true) {
if (abortSignal?.aborted) break;
const { done, value } = await reader.read();
if (done) break;
if (abortSignal?.aborted) break;
chunk += decoder.decode(value, { stream: true });
const lines = chunk.split('\n');
chunk = lines.pop() || ''; // Save incomplete line for next read
chunk = lines.pop() || '';
for (const line of lines) {
if (abortSignal?.aborted) break;
if (line.startsWith('data: ')) {
const data = line.slice(6);
if (data === '[DONE]') {
@@ -287,15 +311,19 @@ export class ChatService {
try {
const parsed: ApiChatCompletionStreamChunk = JSON.parse(data);
const chunkModel = this.extractModelName(parsed);
if (chunkModel && !modelEmitted) {
modelEmitted = true;
onModel?.(chunkModel);
}
const content = parsed.choices[0]?.delta?.content;
const reasoningContent = parsed.choices[0]?.delta?.reasoning_content;
const timings = parsed.timings;
const promptProgress = parsed.prompt_progress;
if (timings || promptProgress) {
this.updateProcessingState(timings, promptProgress);
// Store the latest timing data
this.updateProcessingState(timings, promptProgress, conversationId);
if (timings) {
lastTimings = timings;
}
@@ -304,21 +332,29 @@ export class ChatService {
if (content) {
hasReceivedData = true;
aggregatedContent += content;
onChunk?.(content);
if (!abortSignal?.aborted) {
onChunk?.(content);
}
}
if (reasoningContent) {
hasReceivedData = true;
fullReasoningContent += reasoningContent;
onReasoningChunk?.(reasoningContent);
if (!abortSignal?.aborted) {
onReasoningChunk?.(reasoningContent);
}
}
} catch (e) {
console.error('Error parsing JSON chunk:', e);
}
}
}
if (abortSignal?.aborted) break;
}
if (abortSignal?.aborted) return;
if (streamFinished) {
if (!hasReceivedData && aggregatedContent.length === 0) {
const noResponseError = new Error('No response received from server. Please try again.');
@@ -355,7 +391,8 @@ export class ChatService {
reasoningContent?: string,
timings?: ChatMessageTimings
) => void,
onError?: (error: Error) => void
onError?: (error: Error) => void,
onModel?: (model: string) => void
): Promise<string> {
try {
const responseText = await response.text();
@@ -366,6 +403,12 @@ export class ChatService {
}
const data: ApiChatCompletionResponse = JSON.parse(responseText);
const responseModel = this.extractModelName(data);
if (responseModel) {
onModel?.(responseModel);
}
const content = data.choices[0]?.message?.content || '';
const reasoningContent = data.choices[0]?.message?.reasoning_content;
@@ -445,6 +488,19 @@ export class ChatService {
});
}
// Handle legacy 'context' type from old webui (pasted content)
const legacyContextFiles = message.extra.filter(
(extra: DatabaseMessageExtra): extra is DatabaseMessageExtraLegacyContext =>
extra.type === 'context'
);
for (const legacyContextFile of legacyContextFiles) {
contentParts.push({
type: 'text',
text: `\n\n--- File: ${legacyContextFile.name} ---\n${legacyContextFile.content}`
});
}
const audioFiles = message.extra.filter(
(extra: DatabaseMessageExtra): extra is DatabaseMessageExtraAudioFile =>
extra.type === 'audioFile'
@@ -520,10 +576,18 @@ export class ChatService {
*
* @public
*/
public abort(): void {
if (this.abortController) {
this.abortController.abort();
this.abortController = null;
public abort(conversationId?: string): void {
if (conversationId) {
const abortController = this.abortControllers.get(conversationId);
if (abortController) {
abortController.abort();
this.abortControllers.delete(conversationId);
}
} else {
for (const controller of this.abortControllers.values()) {
controller.abort();
}
this.abortControllers.clear();
}
}
@@ -581,32 +645,66 @@ export class ChatService {
return error;
} catch {
// If we can't parse the error response, return a generic error
const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
fallback.name = 'HttpError';
return fallback;
}
}
private extractModelName(data: unknown): string | undefined {
const asRecord = (value: unknown): Record<string, unknown> | undefined => {
return typeof value === 'object' && value !== null
? (value as Record<string, unknown>)
: undefined;
};
const getTrimmedString = (value: unknown): string | undefined => {
return typeof value === 'string' && value.trim() ? value.trim() : undefined;
};
const root = asRecord(data);
if (!root) return undefined;
// 1) root (some implementations provide `model` at the top level)
const rootModel = getTrimmedString(root.model);
if (rootModel) return rootModel;
// 2) streaming choice (delta) or final response (message)
const firstChoice = Array.isArray(root.choices) ? asRecord(root.choices[0]) : undefined;
if (!firstChoice) return undefined;
// priority: delta.model (first chunk) else message.model (final response)
const deltaModel = getTrimmedString(asRecord(firstChoice.delta)?.model);
if (deltaModel) return deltaModel;
const messageModel = getTrimmedString(asRecord(firstChoice.message)?.model);
if (messageModel) return messageModel;
// avoid guessing from non-standard locations (metadata, etc.)
return undefined;
}
private updateProcessingState(
timings?: ChatMessageTimings,
promptProgress?: ChatMessagePromptProgress
promptProgress?: ChatMessagePromptProgress,
conversationId?: string
): void {
// Calculate tokens per second from timing data
const tokensPerSecond =
timings?.predicted_ms && timings?.predicted_n
? (timings.predicted_n / timings.predicted_ms) * 1000
: 0;
// Update slots service with timing data (async but don't wait)
slotsService
.updateFromTimingData({
prompt_n: timings?.prompt_n || 0,
predicted_n: timings?.predicted_n || 0,
predicted_per_second: tokensPerSecond,
cache_n: timings?.cache_n || 0,
prompt_progress: promptProgress
})
.updateFromTimingData(
{
prompt_n: timings?.prompt_n || 0,
predicted_n: timings?.predicted_n || 0,
predicted_per_second: tokensPerSecond,
cache_n: timings?.cache_n || 0,
prompt_progress: promptProgress
},
conversationId
)
.catch((error) => {
console.warn('Failed to update processing state:', error);
});
@@ -0,0 +1,22 @@
import { base } from '$app/paths';
import { config } from '$lib/stores/settings.svelte';
import type { ApiModelListResponse } from '$lib/types/api';
export class ModelsService {
static async list(): Promise<ApiModelListResponse> {
const currentConfig = config();
const apiKey = currentConfig.apiKey?.toString().trim();
const response = await fetch(`${base}/v1/models`, {
headers: {
...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {})
}
});
if (!response.ok) {
throw new Error(`Failed to fetch model list (status ${response.status})`);
}
return response.json() as Promise<ApiModelListResponse>;
}
}
+85 -17
View File
@@ -37,6 +37,8 @@ export class SlotsService {
private callbacks: Set<(state: ApiProcessingState | null) => void> = new Set();
private isStreamingActive: boolean = false;
private lastKnownState: ApiProcessingState | null = null;
private conversationStates: Map<string, ApiProcessingState | null> = new Map();
private activeConversationId: string | null = null;
/**
* Start streaming session tracking
@@ -75,6 +77,62 @@ export class SlotsService {
return this.isStreamingActive;
}
/**
* Set the active conversation for statistics display
*/
setActiveConversation(conversationId: string | null): void {
this.activeConversationId = conversationId;
this.notifyCallbacks();
}
/**
* Update processing state for a specific conversation
*/
updateConversationState(conversationId: string, state: ApiProcessingState | null): void {
this.conversationStates.set(conversationId, state);
if (conversationId === this.activeConversationId) {
this.lastKnownState = state;
this.notifyCallbacks();
}
}
/**
* Get processing state for a specific conversation
*/
getConversationState(conversationId: string): ApiProcessingState | null {
return this.conversationStates.get(conversationId) || null;
}
/**
* Clear state for a specific conversation
*/
clearConversationState(conversationId: string): void {
this.conversationStates.delete(conversationId);
if (conversationId === this.activeConversationId) {
this.lastKnownState = null;
this.notifyCallbacks();
}
}
/**
* Notify all callbacks with current state
*/
private notifyCallbacks(): void {
const currentState = this.activeConversationId
? this.conversationStates.get(this.activeConversationId) || null
: this.lastKnownState;
for (const callback of this.callbacks) {
try {
callback(currentState);
} catch (error) {
console.error('Error in slots service callback:', error);
}
}
}
/**
* @deprecated Polling is no longer used - timing data comes from ChatService streaming response
* This method logs a warning if called to help identify outdated usage
@@ -100,29 +158,29 @@ export class SlotsService {
/**
* Updates processing state with timing data from ChatService streaming response
*/
async updateFromTimingData(timingData: {
prompt_n: number;
predicted_n: number;
predicted_per_second: number;
cache_n: number;
prompt_progress?: ChatMessagePromptProgress;
}): Promise<void> {
async updateFromTimingData(
timingData: {
prompt_n: number;
predicted_n: number;
predicted_per_second: number;
cache_n: number;
prompt_progress?: ChatMessagePromptProgress;
},
conversationId?: string
): Promise<void> {
const processingState = await this.parseCompletionTimingData(timingData);
// Only update if we successfully parsed the state
if (processingState === null) {
console.warn('Failed to parse timing data - skipping update');
return;
}
this.lastKnownState = processingState;
for (const callback of this.callbacks) {
try {
callback(processingState);
} catch (error) {
console.error('Error in timing callback:', error);
}
if (conversationId) {
this.updateConversationState(conversationId, processingState);
} else {
this.lastKnownState = processingState;
this.notifyCallbacks();
}
}
@@ -143,6 +201,7 @@ export class SlotsService {
...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {})
}
});
if (response.ok) {
const slotsData = await response.json();
if (Array.isArray(slotsData) && slotsData.length > 0) {
@@ -179,6 +238,7 @@ export class SlotsService {
if (contextTotal === null) {
console.warn('No context total available - cannot calculate processing state');
return null;
}
@@ -214,13 +274,21 @@ export class SlotsService {
/**
* Get current processing state
* Returns the last known state from timing data, or null if no data available
* If activeConversationId is set, returns state for that conversation
*/
async getCurrentState(): Promise<ApiProcessingState | null> {
if (this.activeConversationId) {
const conversationState = this.conversationStates.get(this.activeConversationId);
if (conversationState) {
return conversationState;
}
}
if (this.lastKnownState) {
return this.lastKnownState;
}
try {
// Import dynamically to avoid circular dependency
const { chatStore } = await import('$lib/stores/chat.svelte');
const messages = chatStore.activeMessages;

Some files were not shown because too many files have changed in this diff Show More