Compare commits

..

27 Commits

Author SHA1 Message Date
Masato Nakasaka 7b95ea5d11 common: Intentionally leak logger instance to fix hanging on Windows (#22273)
* Changed to leak logger singleton to prevent hanging on Windows

* Fix comment

* Stopped using static vector

Using std::vector will cause g_col to be released before the logger thread exits, causing the logger thread to touch freed memory causing a crash

* Change so all logs are output before exit

* Added debug logging

* added more logging

* Added logging

* Explicitly free logger to avoid hanging on Win

* Reverted to leak logger instance again

* Removed debug log and fixed comment

* Fixed comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-04-29 10:58:43 +03:00
hrushitfujitsu bdc9c743a5 ggml : add sve tuned code for gemm_q8_0_4x8_q8_0() kernel (#21916)
* Added sve tuned code for gemm_q8_0_4x8_q8_0() kernel

* Change arrays to static const in repack.cpp

---------

Co-authored-by: Vithulep <prashant.vithule@fujitsu.com>
2026-04-29 10:57:37 +03:00
Johannes Gäßler 739393beeb TP: fix delayed AllReduce + zero-sized slices (#22489) 2026-04-29 08:55:07 +02:00
Michael Wand fc2b0053ff ggml-cuda: Repost of 21896: Blackwell native NVFP4 support (#22196) 2026-04-29 06:47:42 +08:00
lnigam 7b8443ac78 ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (… (#22286)
* ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32)

Adds MMA-f16 and tile kernel configs, dispatch logic, template instances,
and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to
ncols2=32 to support GQA ratio 32 only.

* Adding check to return BEST_FATTN_KERNEL_NONE in case GQA!=32

* Apply suggestions from code review

Address review comments

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Address review comments and making kernel config default to DQK=512, DV=512 instead of DQK=256,DV=256

* Fixed bug with sinks=1, with ncols=32, there are two warp-groups created but sinks index is same(0,...,15) for both the groups hence with sinks=1, output is not matching with CPU output. Added sink_base which will be base index for each warp_group (threadIdx.y / np)

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-28 21:37:35 +02:00
Daniel Bevenius 5d56effdee convert : add support for Nemotron Nano 3 Omni (#22481)
This commit adds support for NVIDIA Nemotron Nano 3 Omni model enabling
this model to be converted to GGUF.
2026-04-28 19:17:57 +02:00
Jillis ter Hove 52e5f0a5c1 common : re-arm reasoning budget after DONE on new <think> (#22323)
DONE state absorbs all tokens including a new start tag, causing any think blocks after the first to run unbudgeted. Observed on unsloth/Qwen3.6-27B-GGUF which interleaves multiple <think> blocks per response.

Fixed by advancing start_matcher in DONE branch and re-arming to COUNTING with a fresh budget on match. Adds regression test (test-reasoning-budget: test 6).
2026-04-28 19:15:36 +02:00
Matt Corallo f9f33654a6 vulkan: Coalesce Q4_K/Q5_K scale loads (#21751)
Some SPIR-V compilers (notably mesa) don't handle the current
vulkan Q4_K/Q5_K scale load pattern in mul_mat particularly well.
While reading three `u8`s from the 12-byte scale array should (at
least on some hardware) result in loading the full 12 bytes in a
single LOAD followed by whatever extraction is needed, at least
the ANV Intel driver really can't practically perform this
optimization.

`mesa`'s unsigned upper bound logic doesn't handle tracking bounds
through ternary, resulting in the `(is < 4) ? ... : is - 4` having
an infinite upper bound (as it cannot prove `is - 4` doesn't
underflow). While this could still be rectified if mesa looked at
the array bounds, it currently doesn't and `glslc` currently emits
SPIR-V that doesn't allow for this optimization anyway (though
maybe it will at some point, see
https://github.com/KhronosGroup/glslang/issues/4206).

In mul_mat_vecq we took a different approach to loading the same
fields. We read the first two bytes we needed from `scale` then
took a branch before deciding whether we needed to read a third
byte. In mesa this did, indeed, lead to a top-level branch with
conditional loads. As such these loads ended up not being
coalesced either (at least in the ANV driver) resulting in
additional instructions in our hot loop.

Instead, here, we go ahead and force loading the full 12 bytes and
extract the bits we need from the packed-u32s instead. In mul_mat
there's a few less ternaries and only one extra shift, so even on
drivers that did optimize the previous loads properly the only
material change should be pulling a few extra bytes into registers
(which on most hardware won't cost anything anyway, though
ironically on Intel it theoretically could). In mul_mat_vecq this
requires a bit of extra math and may read bytes from the u32 that
weren't needed, but it seems likely avoiding the branch is a win
on most platforms.

On Intel Xe2/mesa 26.0.4 with the optimizations from
https://gitlab.freedesktop.org/mesa/mesa/-/work_items/15162,

for shader matmul_id_subgroup_q4_k_f32_f16acc_aligned_l:
 * Instruction Count: 2753 -> 2688
 * SEND Count: 269 -> 261
 * Cycle Count: 273976 -> 266138
 * Max live registers: 248 -> 246
 * Non SSA regs after NIR: 381 -> 382

for shader matmul_id_subgroup_q5_k_f32_f16acc_aligned_l:
 * Instruction Count: 2767 -> 2702
 * SEND Count: 271 -> 263
 * Cycle Count: 274140 -> 268144
 * Max live registers: 248 -> 246
 * Non SSA regs after NIR: 381 -> 382

for shader mul_mat_vec_id_q4_k_q8_1_f32:
 * Instruction Count: 1930 -> 1646
 * SEND Count: 116 -> 71
 * Cycle Count: 1348306 -> 843350
 * Max live registers: 78 -> 84
 * Non SSA regs after NIR: 300 -> 135

for shader mul_mat_vec_id_q5_k_q8_1_f32:
 * Instruction Count: 2207 -> 1922
 * SEND Count: 131 -> 86
 * Cycle Count: 1392012 -> 1037836
 * Max live registers: 90 -> 90
 * Non SSA regs after NIR: 300 -> 135

for shader mul_mat_vec_q4_k_q8_1_f32:
 * Instruction Count: 2029 -> 1749
 * SEND Count: 111 -> 66
 * Cycle Count: 1347278 -> 840118
 * Max live registers: 74 -> 80
 * Non SSA regs after NIR: 299 -> 134

for shader mul_mat_vec_q5_k_q8_1_f32:
 * Instruction Count: 2307 -> 2022
 * SEND Count: 126 -> 81
 * Cycle Count: 1379820 -> 954042
 * Max live registers: 86 -> 86
 * Non SSA regs after NIR: 299 -> 134

On one Arc Pro B60, unsloth/Qwen3.5-35B-A3B-GGUF:UD-Q4_K_XL:
 * pp512: 907.34 ± 9.28 -> 941.94 ± 10.53 (+4%)
 * pp2048: 897.95 ± 1.82 -> 931.55 ± 1.79 (+4%)
 * tg128: 49.49 ± 0.02 -> 49.86 ± 0.05 (+ <1%)

On one Arc Pro B60, unsloth/Qwen3.5-27B-GGUF:Q4_K_S:
 * pp512: 324.13 ± 10.52 -> 354.33 ± 6.81 (+9%)
 * pp2048: 329.80 ± 0.25 -> 357.10 ± 0.06 (+8%)
 * tg128: 17.11 ± 0.01 -> 18.11 ± 0.01 (+6%)

On four Arc Pro B60s, unsloth/Qwen3.5-122B-A10B-GGUF:Q5_K_S with
-sm layer (note that -sm tensor improvements will naturally be
less):
 * pp512: 264.55 ± 2.81 -> 280.45 ± 3.94 (+6%)
 * pp2048: 319.32 ± 2.72 -> 335.70 ± 3.48 (+5%)
 * tg128: 26.39 ± 0.01 -> 26.67 ± 0.01 (+1%)
2026-04-28 17:31:04 +02:00
Reese Levine 98bb57916a ggml-webgpu: fix buffer aliasing for ssm_scan and refactor aliasing logic (#22456)
* Refactor buffer aliasing to be part of shader lib decisions

* cleanup

* formatting
2026-04-28 07:27:17 -07:00
Aleksander Grygier f42e29fdf1 webui: Server tools (#21237)
* wip: server_tools

* feat: Integrate with `/tools` endpoint

* feat: Builtin + MCP + JSON Schema Tools WIP

* refactor

* displayName -> display_name

* snake_case everywhere

* rm redundant field

* feat: Improvements

* chore: update webui build output

* refactor: Updates after server updates

* chore: update webui build output

* change arg to --tools all

* feat: UI improvements

* chore: update webui build output

* add readme mention

* llama-gen-docs

* chore: update webui build output

* chore: update webui build output

* chore: update webui build output

* feat: Reorganize settings sections

* feat: Separate dialogs for MCP Servers Settings and Import/Export

* feat: WIP

* feat: WIP

* feat: WIP

* feat: WIP

* feat: WIP

* feat: WIP

* WIP on allozaur/20677-webui-server-tools

* feat: UI improvements

* chore: Update package lock

* chore: Run `npm audit fix`

* feat: UI WIP

* feat: UI

* refactor: Desktop Icon Strip DRY

* feat: Cleaner rendering and transition for ChatScreen

* feat: UI improvements

* feat: UI improvement

* feat: Remove MCP Server "enable" switch from Tools submenu

* chore: Run `npm audit fix`

* feat: WIP

* feat: Logic improvements

* refactor: Cleanup

* refactor: DRY

* test: Fix Chat Sidebar UI Tests

* chore: Update package lock

* refactor: Cleanup

* feat: Chat Message Action Card with Continue and Permission flow implementations

* feat: Add agentic steering messages, draft messages and improve chat UX

* fix: Search results UI

* test: Fix unit test

* feat: UI/UX improvements

* refactor: Simplify `useToolsPanel` access in components

* feat: Implement Processing Info Context API

* feat: Implement 'Go back to chat' functionality for settings

* feat: Enhance MCP Server management in Chat Form Attachments

* style: Minor UI and branding adjustments

* chore: Update webui static build output

* chore: Formatting, linting & type checks

* feat: Draft messages logic

* feat: UI improvements

* feat: Steering Messages improvements

* refactor: Cleanup

* refactor: Cleanup

* feat: Improve UI

* refactor: Settings navigation hook

* refactor: DRY code

* refactor: DRY ChatMessageUser UI components

* refactor: Desktop Icon Strip DRY

* refactor: Tools & permissions

* fix: Navigation condition

* refactor: Cleanup

* refactor: Cleanup

* refactor: Cleanup

* fix: preserve reasoning_content in agentic flow

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-04-28 14:35:49 +03:00
Jeff Bolz 19821178be vulkan: add barrier after writetimestamp (#21865) 2026-04-28 12:28:12 +02:00
Emil Askerov 698d19b93c ggml: improve SPIR-V headers detection with __has_include (#21918)
* ggml: improve SPIR-V headers detection with __has_include while preserving original _WIN32 logic

* Address review comments: fix fallback logic and add FreeBSD support

* Remove spirv_cross fallback as per review

* Remove redundant __has_include check
2026-04-28 12:19:06 +02:00
Adrien Gallouët 50494a2800 ggml : skip already registered backends and devices (#22296)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-28 10:02:32 +03:00
Adrien Gallouët d530d6e7a2 ggml : revert to -lm linking instead of find_library (#22355)
* ggml : revert to -lm linking instead of find_library

`find_library(MATH_LIBRARY m)` was introduced recently, but it breaks
CUDA compilation with GGML_STATIC. I could not find any valid use case
where we would prefer `find_library` over the standard `-lm` approach.

This commit is also meant to start a discussion if there is a valid
reason to keep `find_library(MATH_LIBRARY m)`, we should clarify what
problem it was solving and find an alternative fix that does not break
CUDA with GGML_STATIC.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* ggml : use MATH_LIBRARY only if defined

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* ggml : fix initial broken condition

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* ggml : always respect MATH_LIBRARY when defined

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-28 09:56:02 +03:00
hipudding c3e08f4700 CANN: add new ops, optimize existing ops (#21204)
New operators:
- GGML_OP_SET: implement via aclnnInplaceCopy on target region
- GGML_OP_CUMSUM: implement via aclnnCumsum
- GGML_OP_FILL: implement via aclnnInplaceFillScalar
- GGML_OP_DIAG: implement via aclnnInplaceCopy on diagonal strides
- GGML_OP_TRI (lower/lower_diag/upper_diag/upper): implement via
  aclnnTril(-1/0) and aclnnTriu(0/1) with appropriate diagonal offsets
- GGML_OP_SOLVE_TRI: implement via aclnnTriangularSolve
- GGML_UNARY_OP_SOFTPLUS: implement via aclnnSoftplus

Optimizations:
- GLU (SwiGLU/GeGLU/GeGLU_ERF/GeGLU_QUICK): fuse with aclnnSwiGlu /
  aclnnGeGluV3 when applicable; fallback conditions now checked inside
  each function rather than at the call site
- CROSS_ENTROPY_LOSS: replace 5-kernel sequence (LogSoftmax→Mul→
  ReduceSum×2→Muls) with single aclnnSoftmaxCrossEntropyWithLogits call
- L2_NORM: fix in-place ClampMin on norm result (was clamping wrong
  tensor); add eps clamping before division to avoid divide-by-zero
- PAD_REFLECT_1D: eliminate per-ne[3] loop; assert contiguity and call
  ReflectionPad1d once on the full 4-D view; remove redundant nb copies
- GET_ROWS: replace IndexSelect with GatherV2 per batch slice; refactor
  helper into gather_batched lambda with batch loop inlined
- SET_ROWS: replace IndexCopy with InplaceIndexCopy per batch slice;
  refactor helper into scatter_batched lambda with batch loop inlined
- OUT_PROD: replace O(ne[3]*ne[2]*ne[1]) Ger+InplaceAdd loop with
  per-slice Matmul loop (src0 @ src1^T); handles strided-broadcast
  batch dims where ne02/ne03 may differ from ne2/ne3
- backend memset_tensor: implement via aclrtMemset (was NULL)

Bug fixes:
- COUNT_EQUAL: use non-inplace EqTensor into a same-type temporary
  buffer instead of InplaceEqTensor, avoiding corruption of src0
- ACL graph cache (USE_ACL_GRAPH): restore node_type and src_type[]
  fields in ggml_graph_node_properties; has_matching_properties() was
  missing type checks, causing F16 and BF16 tensors (same nb[0]=2) to
  incorrectly share cached graphs and produce wrong results (ERR≈679)
- graph cache op_params matching: compare full GGML_MAX_OP_PARAMS
  bytes so that ops differing only in parameters are not incorrectly
  replayed from cache
2026-04-28 09:27:22 +03:00
Georgi Gerganov 14e733e36f spec : refactor params (#22397)
* spec : refactor params

* cont : fix

* cont : rename "sparam" to "sampling"

* cont : add spec params category

* cont : add info about removed arguments

* cont : skip param length check for spec params

* cont : adapt server tests
2026-04-28 09:07:33 +03:00
Aman Gupta 516e8d7a8a server: use pos_next instead of n_tokens for m-rope (#22439) 2026-04-28 08:41:00 +03:00
Rithik Sharma 434b2a1ff6 ggml-webgpu: add Q1_0 support (#22374)
* add fast matmul matvec q1_0 kernel

* ggml-webgpu: drop redundant zero-fills in Q1_0 shmem init
2026-04-27 15:50:59 -07:00
tha80 983ca8992e server: (router) Forward form-data to model server (Fixes #22044) (#22118)
* This commit enables the router to forward form-data to model server.
Fixes #22044 (enabling to use the /v1/audio/transcriptions in router mode)

* * Applied the suggestion from Copilots first comment: using the non-throwing json::parse overload.
* Addressed Copilots third comment by extending the files representation to also include filename and content-type
* Addressed Copilots fourth comment by making the RNG thread_local

* Changed variable body from std::string to std::ostringstream in build_multipart_body
as suggested by ngxson in https://github.com/ggml-org/llama.cpp/pull/22118#discussion_r3127099053

* Added sanitize_field lambda in build_multipart_body for key, filename and content_type
as suggested by ngxson in https://github.com/ggml-org/llama.cpp/pull/22118#discussion_r3127104647

* explicitly checking if value/item is string before calling value/item.get<std::string>()
as requested by ngxson in https://github.com/ggml-org/llama.cpp/pull/22118#discussion_r3127111279

* Added double quote to the sanitize lambda and throw on json parse failure

---------

Co-authored-by: Ralph Paßgang <ralph@trust-it.de>
2026-04-27 23:55:00 +02:00
Rithik Sharma 665abc6097 add fast mat-vec kernels for i-quants (#22344) 2026-04-27 08:25:45 -07:00
Igor Rudenko 4414c04b9a Additional test for common/gemma4 : handle parsing edge cases (#22420)
* Additional test for common/gemma4 : handle parsing edge cases

* Move tests to Gemma 4 test group
2026-04-27 16:36:59 +02:00
unraido ceaf47c4b1 fix: rpc-server cache may not work in Windows environments (#22394)
* fix: create directory and log cache file name.

* Remove GGML_LOG_INFO conditional compilation.

---------

Co-authored-by: kotaro <kotaro.kusunoki@gmail.com>
2026-04-27 17:25:09 +03:00
rankaiyx 42401c72b8 Fix type casting for unaccounted memory calculation (#22424) 2026-04-27 14:31:13 +02:00
Georgi Gerganov e940b3d468 download : prefer q8_0 when q4_k not available (#22428) 2026-04-27 14:30:29 +02:00
ynankani 0f1bb602dd model : remove duplicate wo_s scale after build_attn (Qwen3, LLaMA) (#22421)
Signed-off-by: Yash Nankani <ynankani@nvidia.com>
2026-04-27 09:58:48 +02:00
Sigbjørn Skjæret d13540becd convert : remove input_scale for dequantized fp8 modelopt (#22356) 2026-04-27 08:45:01 +02:00
Adrien Gallouët f84270ea10 ggml : use 64 bytes aligned tile buffers (#21058)
| Model                            | Test   |   t/s OLD |   t/s NEW |   Speedup |
|:---------------------------------|:-------|----------:|----------:|----------:|
| qwen35 0.8B BF16                 | pp512  |    584.59 |    595.41 |      1.02 |
| qwen35 0.8B BF16                 | tg128  |     52.23 |     52.82 |      1.01 |
| qwen35 0.8B IQ2_M - 2.7 bpw      | pp512  |    260.64 |    261.70 |      1.00 |
| qwen35 0.8B IQ2_M - 2.7 bpw      | tg128  |     81.17 |     80.89 |      1.00 |
| qwen35 0.8B IQ2_XXS - 2.0625 bpw | pp512  |    302.36 |    302.56 |      1.00 |
| qwen35 0.8B IQ2_XXS - 2.0625 bpw | tg128  |     84.93 |     85.12 |      1.00 |
| qwen35 0.8B IQ3_XXS - 3.0625 bpw | pp512  |    263.22 |    260.01 |      0.99 |
| qwen35 0.8B IQ3_XXS - 3.0625 bpw | tg128  |     80.29 |     78.94 |      0.98 |
| qwen35 0.8B IQ4_NL - 4.5 bpw     | pp512  |    728.65 |    742.09 |      1.02 |
| qwen35 0.8B IQ4_NL - 4.5 bpw     | tg128  |     82.39 |     84.46 |      1.03 |
| qwen35 0.8B IQ4_XS - 4.25 bpw    | pp512  |    681.33 |    677.06 |      0.99 |
| qwen35 0.8B IQ4_XS - 4.25 bpw    | tg128  |     80.18 |     79.28 |      0.99 |
| qwen35 0.8B Q2_K_M               | pp512  |    413.28 |    415.94 |      1.01 |
| qwen35 0.8B Q2_K_M               | tg128  |     81.90 |     82.78 |      1.01 |
| qwen35 0.8B Q3_K_M               | pp512  |    493.17 |    495.08 |      1.00 |
| qwen35 0.8B Q3_K_M               | tg128  |     82.75 |     83.23 |      1.01 |
| qwen35 0.8B Q3_K_S               | pp512  |    429.35 |    427.64 |      1.00 |
| qwen35 0.8B Q3_K_S               | tg128  |     86.69 |     87.02 |      1.00 |
| qwen35 0.8B Q4_0                 | pp512  |    783.46 |    782.32 |      1.00 |
| qwen35 0.8B Q4_0                 | tg128  |     88.23 |     87.90 |      1.00 |
| qwen35 0.8B Q4_1                 | pp512  |    741.71 |    729.76 |      0.98 |
| qwen35 0.8B Q4_1                 | tg128  |     85.44 |     86.01 |      1.01 |
| qwen35 0.8B Q4_K_M               | pp512  |    676.24 |    681.31 |      1.01 |
| qwen35 0.8B Q4_K_M               | tg128  |     76.59 |     77.06 |      1.01 |
| qwen35 0.8B Q4_K_S               | pp512  |    683.12 |    688.81 |      1.01 |
| qwen35 0.8B Q4_K_S               | tg128  |     80.50 |     81.19 |      1.01 |
| qwen35 0.8B Q5_K_M               | pp512  |    635.33 |    642.11 |      1.01 |
| qwen35 0.8B Q5_K_M               | tg128  |     72.07 |     72.49 |      1.01 |
| qwen35 0.8B Q5_K_S               | pp512  |    660.95 |    658.18 |      1.00 |
| qwen35 0.8B Q5_K_S               | tg128  |     72.19 |     72.95 |      1.01 |
| qwen35 0.8B Q6_K                 | pp512  |    647.97 |    638.84 |      0.99 |
| qwen35 0.8B Q6_K                 | tg128  |     72.83 |     72.49 |      1.00 |
| qwen35 0.8B Q8_0                 | pp512  |    805.01 |    785.49 |      0.98 |
| qwen35 0.8B Q8_0                 | tg128  |     70.10 |     70.13 |      1.00 |

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-27 09:30:55 +03:00
212 changed files with 14457 additions and 9504 deletions
+377 -217
View File
File diff suppressed because it is too large Load Diff
+4 -2
View File
@@ -25,7 +25,8 @@ struct common_arg {
const char * value_hint_2 = nullptr; // for second arg value
const char * env = nullptr;
std::string help;
bool is_sparam = false; // is current arg a sampling param?
bool is_sampling = false; // is current arg a sampling param?
bool is_spec = false; // is current arg a speculative decoding param?
bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
void (*handler_void) (common_params & params) = nullptr;
void (*handler_string) (common_params & params, const std::string &) = nullptr;
@@ -74,7 +75,8 @@ struct common_arg {
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
common_arg & set_env(const char * env);
common_arg & set_sparam();
common_arg & set_sampling();
common_arg & set_spec();
common_arg & set_preset_only();
bool in_example(enum llama_example ex);
bool is_exclude(enum llama_example ex);
+7 -7
View File
@@ -70,7 +70,7 @@ common_time_meas::~common_time_meas() {
// CPU utils
//
int32_t cpu_get_num_physical_cores() {
int32_t common_cpu_get_num_physical_cores() {
#ifdef __linux__
// enumerate the set of thread siblings, num entries is num cores
std::unordered_set<std::string> siblings;
@@ -185,11 +185,11 @@ static int cpu_count_math_cpus(int n_cpu) {
/**
* Returns number of CPUs on system that are useful for math.
*/
int32_t cpu_get_num_math() {
int32_t common_cpu_get_num_math() {
#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
if (n_cpu < 1) {
return cpu_get_num_physical_cores();
return common_cpu_get_num_physical_cores();
}
if (is_hybrid_cpu()) {
cpu_set_t affinity;
@@ -202,7 +202,7 @@ int32_t cpu_get_num_math() {
}
}
#endif
return cpu_get_num_physical_cores();
return common_cpu_get_num_physical_cores();
}
// Helper for setting process priority
@@ -263,7 +263,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
//
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model) {
int32_t n_set = 0;
if (cpuparams.n_threads < 0) {
@@ -271,7 +271,7 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
if (role_model != nullptr) {
cpuparams = *role_model;
} else {
cpuparams.n_threads = cpu_get_num_math();
cpuparams.n_threads = common_cpu_get_num_math();
}
}
@@ -1521,7 +1521,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
return cparams;
}
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params) {
struct ggml_threadpool_params tpp;
ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
+56 -36
View File
@@ -54,7 +54,7 @@ struct common_control_vector_load_info;
// CPU utils
//
struct cpu_params {
struct common_cpu_params {
int n_threads = -1;
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
bool mask_valid = false; // Default: any CPU
@@ -63,8 +63,8 @@ struct cpu_params {
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
};
int32_t cpu_get_num_physical_cores();
int32_t cpu_get_num_math();
int32_t common_cpu_get_num_physical_cores();
int32_t common_cpu_get_num_math();
//
// Common params
@@ -297,34 +297,19 @@ struct common_params_model {
struct common_ngram_mod;
struct common_params_speculative {
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
// draft-model-based speculative decoding parameters
struct common_params_speculative_draft {
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
// general-purpose speculative decoding parameters
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
common_params_model mparams;
// ngram-based speculative decoding
llama_model * model = nullptr; // a llama_model that can be shared by multiple speculative contexts
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
std::shared_ptr<common_ngram_mod> ngram_mod;
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
// draft-model speculative decoding
struct common_params_model mparams_dft;
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
llama_context_params cparams; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
@@ -332,25 +317,60 @@ struct common_params_speculative {
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
};
struct common_params_speculative_ngram_mod {
int32_t n_match = 24;
int32_t n_max = 64;
int32_t n_min = 48;
// shared instance of the ngram container for all speculative decoding contexts
std::shared_ptr<common_ngram_mod> obj;
};
struct common_params_speculative_ngram_map {
uint16_t size_n = 12; // ngram size for lookup
uint16_t size_m = 48; // mgram size for speculative tokens
uint16_t min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
};
struct common_params_speculative_ngram_cache {
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
};
struct common_params_speculative {
// TODO: become a vector in order to support "chains of speculators"
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
common_params_speculative_draft draft;
common_params_speculative_ngram_mod ngram_mod;
common_params_speculative_ngram_map ngram_simple;
common_params_speculative_ngram_map ngram_map_k;
common_params_speculative_ngram_map ngram_map_k4v;
common_params_speculative_ngram_cache ngram_cache;
bool has_dft() const {
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
}
};
struct common_params_vocoder {
struct common_params_model model;
std::string speaker_file = ""; // speaker file path // NOLINT
std::string speaker_file; // speaker file path
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy
};
struct common_params_diffusion {
@@ -433,8 +453,8 @@ struct common_params {
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
common_cpu_params cpuparams;
common_cpu_params cpuparams_batch;
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;
@@ -678,7 +698,7 @@ std::string common_params_get_system_info(const common_params & params);
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_params * role_model = nullptr);
bool set_process_priority(enum ggml_sched_priority prio);
//
@@ -846,7 +866,7 @@ common_init_result_ptr common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const common_cpu_params & params);
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
+1 -1
View File
@@ -627,7 +627,7 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
if (!tag.empty()) {
tags.push_back(tag);
} else {
tags = {"Q4_K_M", "Q4_0"};
tags = {"Q4_K_M", "Q8_0"};
}
for (const auto & t : tags) {
+2 -2
View File
@@ -856,7 +856,7 @@ void common_memory_breakdown_print(const struct llama_context * ctx) {
ggml_backend_dev_memory(dev, &free, &total);
const size_t self = mb.model + mb.context + mb.compute;
const size_t unaccounted = total - self - free;
const int64_t unaccounted = static_cast<int64_t>(total) - static_cast<int64_t>(free) - static_cast<int64_t>(self);
table_data.push_back({
template_gpu,
@@ -867,7 +867,7 @@ void common_memory_breakdown_print(const struct llama_context * ctx) {
std::to_string(mb.model / MiB),
std::to_string(mb.context / MiB),
std::to_string(mb.compute / MiB),
std::to_string(unaccounted / MiB)});
std::to_string(unaccounted / static_cast<int64_t>(MiB))});
}
// print memory breakdown for host:
+11 -8
View File
@@ -49,7 +49,7 @@ enum common_log_col : int {
};
// disable colors by default
static std::vector<const char *> g_col = {
static const char* g_col[] = {
"",
"",
"",
@@ -247,7 +247,6 @@ public:
entries = std::move(new_entries);
}
cv.notify_one();
}
@@ -265,7 +264,6 @@ public:
{
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [this]() { return head != tail; });
cur = entries[head];
head = (head + 1) % entries.size();
@@ -301,7 +299,6 @@ public:
tail = (tail + 1) % entries.size();
}
cv.notify_one();
}
@@ -338,7 +335,7 @@ public:
g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
} else {
for (size_t i = 0; i < g_col.size(); i++) {
for (size_t i = 0; i < std::size(g_col); i++) {
g_col[i] = "";
}
}
@@ -368,14 +365,20 @@ struct common_log * common_log_init() {
}
struct common_log * common_log_main() {
static struct common_log log;
// We intentionally leak (i.e. do not delete) the logger singleton because
// common_log destructor called at DLL teardown phase will cause hanging on Windows.
// OS will release resources anyway so it should not be a significant issue,
// though this design may cause logs to be lost if not flushed before the program exits.
// Refer to https://github.com/ggml-org/llama.cpp/issues/22142 for details.
static struct common_log * log;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
log = new common_log;
// Set default to auto-detect colors
log.set_colors(tty_can_use_colors());
log->set_colors(tty_can_use_colors());
});
return &log;
return log;
}
void common_log_pause(struct common_log * log) {
+5 -1
View File
@@ -49,7 +49,11 @@ void common_log_default_callback(enum ggml_log_level level, const char * text, v
struct common_log;
struct common_log * common_log_init();
struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
// Singleton, intentionally leaked to avoid Windows teardown hangs.
// Call common_log_flush() before exit if you want to ensure all logs are flushed.
struct common_log * common_log_main();
void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe
void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
void common_log_free (struct common_log * log);
+1 -1
View File
@@ -43,7 +43,7 @@ static std::set<std::string> get_remote_preset_whitelist(const std::map<std::str
for (const auto & it : key_to_opt) {
const std::string & key = it.first;
const common_arg & opt = it.second;
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
allowed_keys.insert(key);
// also add variant keys (args without leading dashes and env vars)
for (const auto & arg : opt.get_args()) {
+14
View File
@@ -122,6 +122,20 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
}
break;
case REASONING_BUDGET_DONE:
// Re-arm on a new start tag: some models emit multiple <think> blocks
// per response, and each should get a fresh budget window.
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
}
}
break;
}
}
+140 -46
View File
@@ -151,6 +151,9 @@ struct common_speculative_state {
llama_tokens & result) = 0;
virtual void accept(uint16_t n_accepted) = 0;
virtual int32_t n_max(const common_params_speculative & params) const = 0;
virtual int32_t n_min(const common_params_speculative & params) const = 0;
};
struct common_speculative_checkpoint {
@@ -296,6 +299,8 @@ struct common_speculative_state_draft : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
const auto & sparams = params.draft;
auto * spec = this;
auto & batch = spec->batch;
@@ -309,7 +314,7 @@ struct common_speculative_state_draft : public common_speculative_state {
int reuse_i = 0; // index of part to be reused in prompt_dft
int reuse_n = 0; // length of part to be reused in prompt_dft
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
const int n_ctx = llama_n_ctx(ctx_dft) - sparams.n_max;
llama_tokens prompt_cnv;
if (!spec->vocab_cmpt) {
@@ -367,7 +372,7 @@ struct common_speculative_state_draft : public common_speculative_state {
}
result.clear();
result.reserve(params.n_max);
result.reserve(sparams.n_max);
bool needs_ckpt = use_ckpt && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_ckpt && reuse_i > 0)) {
@@ -380,7 +385,7 @@ struct common_speculative_state_draft : public common_speculative_state {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);
if (params.n_max <= (int) result.size()) {
if (sparams.n_max <= (int) result.size()) {
break;
}
}
@@ -473,7 +478,7 @@ struct common_speculative_state_draft : public common_speculative_state {
common_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_max; ++i) {
for (int i = 0; i < sparams.n_max; ++i) {
common_batch_clear(batch);
common_sampler_sample(smpl, ctx_dft, 0, true);
@@ -492,12 +497,12 @@ struct common_speculative_state_draft : public common_speculative_state {
result.push_back(id);
if (params.n_max <= (int) result.size()) {
if (sparams.n_max <= (int) result.size()) {
break;
}
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
if (cur_p->data[0].p < sparams.p_min) {
break;
}
@@ -518,10 +523,14 @@ struct common_speculative_state_draft : public common_speculative_state {
detokenized = replace_to_tgt(detokenized);
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
result = common_tokenize(ctx_tgt, detokenized, false, true);
if (result.size() > (size_t)params.n_max) {
result.resize(params.n_max);
if (result.size() > (size_t) sparams.n_max) {
result.resize(sparams.n_max);
}
}
if (result.size() < (size_t) sparams.n_min) {
result.clear();
}
}
void accept(uint16_t n_accepted) override {
@@ -529,6 +538,14 @@ struct common_speculative_state_draft : public common_speculative_state {
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & params) const override {
return params.draft.n_max;
}
int32_t n_min(const common_params_speculative & params) const override {
return params.draft.n_min;
}
std::string replace_to_dft(const std::string & input) const {
std::string result = input;
@@ -581,6 +598,14 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
// noop
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & params) const override {
return params.draft.n_max;
}
int32_t n_min(const common_params_speculative & params) const override {
return params.draft.n_min;
}
};
// state of self-speculation (simple implementation, not ngram-map)
@@ -610,19 +635,27 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
// noop
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & /*params*/) const override {
return config.size_mgram;
}
int32_t n_min(const common_params_speculative & /*params*/) const override {
return config.size_mgram;
}
};
struct common_speculative_state_ngram_map_k : public common_speculative_state {
// draft ngram map for speculative decoding without draft model
common_ngram_map map;
common_ngram_map config;
common_speculative_state_ngram_map_k(
enum common_speculative_type type,
common_ngram_map map)
: common_speculative_state(type), map(std::move(map)) {}
common_ngram_map config)
: common_speculative_state(type), config(std::move(config)) {}
void begin(const llama_tokens & prompt) override {
common_ngram_map_begin(map, prompt);
common_ngram_map_begin(config, prompt);
}
void draft(
@@ -630,12 +663,20 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
common_ngram_map_draft(map, prompt_tgt, id_last, result);
common_ngram_map_draft(config, prompt_tgt, id_last, result);
GGML_UNUSED(params);
}
void accept(uint16_t n_accepted) override {
common_ngram_map_accept(map, n_accepted);
common_ngram_map_accept(config, n_accepted);
}
int32_t n_max(const common_params_speculative & /*params*/) const override {
return config.size_value;
}
int32_t n_min(const common_params_speculative & /*params*/) const override {
return config.size_value;
}
};
@@ -692,7 +733,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(params);
const auto & sparams = params.ngram_mod;
n_draft_last = 0;
@@ -712,16 +753,16 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
i_last = cur_len - n;
}
result.resize(n + params.n_max);
result.resize(n + sparams.n_max);
for (size_t i = 0; i < n - 1; ++i) {
result[i] = prompt_tgt[cur_len - n + 1 + i];
}
result[n - 1] = id_last;
for (int i = 0; i < params.n_max; ++i) {
for (int i = 0; i < sparams.n_max; ++i) {
const llama_token token = mod.get(result.data() + i);
if (token == common_ngram_mod::EMPTY) {
if (i < params.n_min) {
if (i < sparams.n_min) {
result.clear();
return;
}
@@ -764,6 +805,14 @@ struct common_speculative_state_ngram_mod : public common_speculative_state {
}
}
}
int32_t n_max(const common_params_speculative & params) const override {
return params.ngram_mod.n_max;
}
int32_t n_min(const common_params_speculative & params) const override {
return params.ngram_mod.n_min;
}
};
struct common_speculative_state_ngram_cache : public common_speculative_state {
@@ -857,6 +906,14 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
// TODO: noop
GGML_UNUSED(n_accepted);
}
int32_t n_max(const common_params_speculative & /*params*/) const override {
return n_draft;
}
int32_t n_min(const common_params_speculative & /*params*/) const override {
return 0;
}
};
struct common_speculative {
@@ -865,11 +922,13 @@ struct common_speculative {
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
};
static common_ngram_map get_common_ngram_map(const common_speculative_config & config) {
uint16_t size_key = config.params.ngram_size_n;
uint16_t size_value = config.params.ngram_size_m;
bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
uint16_t min_hits = config.params.ngram_min_hits;
static common_ngram_map get_common_ngram_map(
common_speculative_type type,
const common_params_speculative_ngram_map & config) {
uint16_t size_key = config.size_n;
uint16_t size_value = config.size_m;
bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
uint16_t min_hits = config.min_hits;
return common_ngram_map(size_key, size_value, key_only, min_hits);
}
@@ -927,8 +986,8 @@ common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt) {
llama_context * ctx_dft = nullptr;
if (params.model_dft) {
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
if (params.draft.model) {
ctx_dft = llama_init_from_model(params.draft.model, params.draft.cparams);
if (ctx_dft == nullptr) {
LOG_ERR("%s", "failed to create draft context\n");
return nullptr;
@@ -938,7 +997,7 @@ common_speculative * common_speculative_init(
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft = !params.draft.mparams.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
@@ -961,16 +1020,17 @@ common_speculative * common_speculative_init(
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
}
if (has_ngram_mod) {
// shared instance for all speculative decoding contexts
if (!params.ngram_mod) {
params.ngram_mod = std::make_shared<common_ngram_mod>(params.ngram_size_n, 4*1024*1024);
auto & sparams = params.ngram_mod;
LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__,
params.ngram_size_n, params.ngram_mod->size(),
(float)(params.ngram_mod->size_bytes())/1024/1024);
if (!sparams.obj) {
sparams.obj = std::make_shared<common_ngram_mod>(sparams.n_match, 4*1024*1024);
if (params.ngram_size_n < 16) {
LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n);
LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__,
sparams.n_match, sparams.obj->size(), (float)(sparams.obj->size_bytes())/1024/1024);
if (sparams.n_match < 16) {
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, sparams.n_match);
}
}
@@ -1000,7 +1060,7 @@ common_speculative * common_speculative_init(
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements,
/* .replacements = */ params.draft.replacements,
/* .use_ckpt = */ use_ckpt
));
break;
@@ -1010,18 +1070,18 @@ common_speculative * common_speculative_init(
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config);
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
uint16_t ngram_size_key = ngram_map.size_key;
uint16_t mgram_size_value = ngram_map.size_value;
auto config_simple = common_ngram_simple_config {
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type,
/* .state = */ config_simple
/* .type = */ config.type,
/* .state = */ config_simple
);
impls.push_back(std::move(state));
break;
@@ -1030,18 +1090,17 @@ common_speculative * common_speculative_init(
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
impls.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
(config.type),
get_common_ngram_map(config)
get_common_ngram_map(config.type, config.params.ngram_map_k)
));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
GGML_ASSERT(config.params.ngram_mod);
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod));
GGML_ASSERT(config.params.ngram_mod.obj);
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod.obj));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
auto state = create_state_ngram_cache(
params.lookup_cache_static, params.lookup_cache_dynamic, config);
auto state = create_state_ngram_cache(params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic, config);
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
break;
}
@@ -1099,6 +1158,15 @@ llama_tokens common_speculative_draft(
impl->n_call_draft++;
}
{
const int n_min = impl->n_min(params);
if (!result.empty() && (int) result.size() < n_min) {
LOG_DBG("%s: ignoring small draft: %d < %d\n", __func__, (int) result.size(), n_min);
result.clear();
}
}
if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
@@ -1108,7 +1176,7 @@ llama_tokens common_speculative_draft(
impl->n_gen_drafts++;
impl->n_gen_tokens += result.size();
break; // We have a draft, so break out of the loop and return it.
break; // we have a draft, so break out of the loop and return it.
}
}
@@ -1136,6 +1204,32 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
}
}
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) {
if (spec == nullptr) {
return 0;
}
int32_t n_max = 0;
for (const auto & impl : spec->impls) {
n_max = std::max(n_max, impl->n_max(params));
}
return n_max;
}
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params) {
if (spec == nullptr) {
return 0;
}
int32_t n_min = 0;
for (const auto & impl : spec->impls) {
n_min = std::max(n_min, impl->n_min(params));
}
return n_min;
}
void common_speculative_print_stats(const common_speculative * spec) {
if (spec == nullptr) {
return;
+3
View File
@@ -33,6 +33,9 @@ llama_tokens common_speculative_draft(
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
+49 -32
View File
@@ -272,6 +272,22 @@ class ModelBase:
return tensors
@staticmethod
def _scale_is_trivial(scale: Tensor) -> bool:
return scale.numel() <= 1 and abs(float(scale.float().sum()) - 1.0) < 1e-6
def _write_scale_tensor(self, scale_name: str, scale: Tensor):
if not self._scale_is_trivial(scale):
scale_f32 = scale.float().numpy().flatten()
logger.info(f" + {scale_name} (per-tensor scale, shape [{scale_f32.size}])")
self.gguf_writer.add_tensor(scale_name, scale_f32)
def _write_scales_tensor(self, scale_name: str, scales: list[float]):
if not np.allclose(scales, 1.0, atol=1e-6):
scale_vals = np.array(scales, dtype=np.float32)
logger.info(f" + {scale_name} (per-expert scale, shape [{len(scales)}])")
self.gguf_writer.add_tensor(scale_name, scale_vals)
def dequant_model(self):
# If all quantized tensors were already handled (e.g. pure NVFP4), skip
if self._is_nvfp4 and not any(k.endswith((".weight_scale", ".weight_scale_inv")) for k in self.model_tensors):
@@ -494,7 +510,7 @@ class ModelBase:
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
tensors_to_remove.append(name)
if name.endswith((".k_scale", ".v_scale")):
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
tensors_to_remove.append(name)
elif quant_method is not None:
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
@@ -602,10 +618,6 @@ class ModelBase:
raw = np.concatenate([d_grouped, qs_grouped], axis=-1).reshape(out_features, n_super * 36)
return raw, [out_features, n_super * 64]
@staticmethod
def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool:
return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6
def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
if "language_model." in name:
name = name.replace("language_model.", "")
@@ -616,19 +628,8 @@ class ModelBase:
logger.info(f"Repacked {new_name} with shape {shape} and quantization NVFP4")
self.gguf_writer.add_tensor(new_name, raw, raw_dtype=gguf.GGMLQuantizationType.NVFP4)
# Emit per-tensor scale2 as a separate F32 tensor when non-trivial
if not self._nvfp4_scale2_is_trivial(scale2):
scale2_f32 = scale2.float().numpy().flatten()
scale_name = new_name.replace(".weight", ".scale")
logger.info(f" + {scale_name} (per-tensor NVFP4 scale2, shape [{scale2_f32.size}])")
self.gguf_writer.add_tensor(scale_name, scale2_f32)
# Emit per-tensor input_scale as a separate F32 tensor when non-trivial
if not self._nvfp4_scale2_is_trivial(input_scale):
input_scale_f32 = input_scale.float().numpy().flatten()
input_scale_name = new_name.replace(".weight", ".input_scale")
logger.info(f" + {input_scale_name} (per-tensor NVFP4 input_scale, shape [{input_scale_f32.size}])")
self.gguf_writer.add_tensor(input_scale_name, input_scale_f32)
self._write_scale_tensor(new_name.replace(".weight", ".scale"), scale2)
self._write_scale_tensor(new_name.replace(".weight", ".input_scale"), input_scale)
def _generate_nvfp4_tensors(self):
# Per-layer expert merging to avoid holding all experts in memory
@@ -719,24 +720,17 @@ class ModelBase:
logger.info(f"Repacked {new_name} with shape [{len(experts)}, {shape[0]}, {shape[1]}] and quantization NVFP4")
self.gguf_writer.add_tensor(new_name, merged, raw_dtype=gguf.GGMLQuantizationType.NVFP4)
# Emit per-expert scale2 tensor if any expert has non-trivial scale2
scales.sort(key=lambda x: x[0])
scale_vals = np.array([s[1] for s in scales], dtype=np.float32)
if not np.allclose(scale_vals, 1.0, atol=1e-6):
scale_name = new_name.replace(".weight", ".scale")
logger.info(f" + {scale_name} (per-expert NVFP4 scale2, shape [{len(scales)}])")
self.gguf_writer.add_tensor(scale_name, scale_vals)
self._write_scales_tensor(new_name.replace(".weight", ".scale"), [s[1] for s in scales])
# Emit per-expert input_scale tensor if any expert has non-trivial input_scale
input_scales.sort(key=lambda x: x[0])
input_scale_vals = np.array([s[1] for s in input_scales], dtype=np.float32)
if not np.allclose(input_scale_vals, 1.0, atol=1e-6):
input_scale_name = new_name.replace(".weight", ".input_scale")
logger.info(f" + {input_scale_name} (per-expert NVFP4 input_scale, shape [{len(input_scales)}])")
self.gguf_writer.add_tensor(input_scale_name, input_scale_vals)
self._write_scales_tensor(new_name.replace(".weight", ".input_scale"), [s[1] for s in input_scales])
del experts, merged
def _needs_nvfp4_processing(self) -> bool:
return True
def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
@@ -767,7 +761,7 @@ class ModelBase:
# NVFP4 weights are repacked and written directly to gguf_writer.
# This must run before dequant_model so NVFP4 tensors are removed
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
if self._is_nvfp4:
if self._is_nvfp4 and self._needs_nvfp4_processing():
self._generate_nvfp4_tensors()
self.dequant_model()
@@ -2199,6 +2193,10 @@ class MmprojModel(ModelBase):
# merge configs
self.preprocessor_config = {**self.preprocessor_config, **cfg}
def _needs_nvfp4_processing(self) -> bool:
# nvfp4 quantization applies to the text model only.
return False
def get_vision_config(self) -> dict[str, Any] | None:
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
return self.global_config.get(config_name)
@@ -4459,6 +4457,12 @@ class NemotronNanoV2VLModel(MmprojModel):
}
return vision_config
def dequant_model(self):
if self._is_nvfp4:
# Skip nvfp4 quantization for vision/audio model.
return
super().dequant_model()
def set_gguf_parameters(self):
if "image_mean" not in self.preprocessor_config:
self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
@@ -4482,6 +4486,10 @@ class NemotronNanoV2VLModel(MmprojModel):
if "input_conditioner" in name:
return
# mtmd does not support video yet so skip tensors related to video.
if "radio_model.model.patch_generator.video_embedder" in name:
return
# RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
if "patch_generator.pos_embed" in name:
if not name.endswith(".weight"):
@@ -10829,7 +10837,11 @@ class NemotronHModel(GraniteHybridModel):
# uses self.model_arch to build the tensor name map, and all MoE-specific
# mappings would be missed if it were called with the default non-MoE arch.
hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
if "num_experts_per_tok" in hparams:
has_moe_params = (
"num_experts_per_tok" in hparams
or (isinstance(hparams.get("llm_config"), dict) and "num_experts_per_tok" in hparams["llm_config"])
)
if has_moe_params:
self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
self.is_moe = True
@@ -10976,6 +10988,11 @@ class NemotronHModel(GraniteHybridModel):
if name.startswith(("vision_model.", "mlp1.")):
return
if name.startswith(("sound_encoder.")):
return
if name.startswith(("sound_projection.")):
return
# Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
if name.startswith("language_model."):
name = name[len("language_model."):]
+4 -4
View File
@@ -73,12 +73,12 @@ static void write_help(std::ostringstream & ss, const md_file & md) {
auto ctx_arg = common_params_parser_init(params, md.ex);
std::vector<common_arg *> common_options;
std::vector<common_arg *> sparam_options;
std::vector<common_arg *> sampling_options;
std::vector<common_arg *> specific_options;
for (auto & opt : ctx_arg.options) {
// in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
if (opt.is_sparam) {
sparam_options.push_back(&opt);
if (opt.is_sampling) {
sampling_options.push_back(&opt);
} else if (opt.in_example(ctx_arg.ex)) {
specific_options.push_back(&opt);
} else {
@@ -93,7 +93,7 @@ static void write_help(std::ostringstream & ss, const md_file & md) {
ss << "### Common params\n\n";
write_table(ss, common_options);
ss << "\n\n### Sampling params\n\n";
write_table(ss, sparam_options);
write_table(ss, sampling_options);
ss << "\n\n### " << md.specific_section_header << "\n\n";
write_table(ss, specific_options);
+2 -2
View File
@@ -37,9 +37,9 @@ int main(int argc, char ** argv){
common_ngram_cache ngram_cache;
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.ngram_cache.lookup_cache_static.c_str());
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
common_ngram_cache_save(ngram_cache, params.speculative.ngram_cache.lookup_cache_static);
return 0;
}
+6 -6
View File
@@ -24,7 +24,7 @@ int main(int argc, char ** argv){
return 1;
}
const int n_draft = params.speculative.n_max;
const int n_draft = params.speculative.draft.n_max;
// init llama.cpp
llama_backend_init();
@@ -49,18 +49,18 @@ int main(int argc, char ** argv){
{
const int64_t t_start_draft_us = ggml_time_us();
if (!params.speculative.lookup_cache_static.empty()) {
if (!params.speculative.ngram_cache.lookup_cache_static.empty()) {
try {
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
ngram_cache_static = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_static);
} catch (std::ifstream::failure const &) {
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
LOG_ERR("failed to open static lookup cache: %s", params.speculative.ngram_cache.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.speculative.lookup_cache_dynamic.empty()) {
if (!params.speculative.ngram_cache.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}
+7 -7
View File
@@ -25,7 +25,7 @@ int main(int argc, char ** argv){
}
// max. number of additional tokens to draft if match is found
const int n_draft = params.speculative.n_max;
const int n_draft = params.speculative.draft.n_max;
// init llama.cpp
llama_backend_init();
@@ -54,18 +54,18 @@ int main(int argc, char ** argv){
const int64_t t_start_draft_us = ggml_time_us();
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
if (!params.speculative.lookup_cache_static.empty()) {
if (!params.speculative.ngram_cache.lookup_cache_static.empty()) {
try {
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
ngram_cache_static = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_static);
} catch (std::ifstream::failure const &) {
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
LOG_ERR("failed to open static lookup cache: %s", params.speculative.ngram_cache.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.speculative.lookup_cache_dynamic.empty()) {
if (!params.speculative.ngram_cache.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.ngram_cache.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}
@@ -213,7 +213,7 @@ int main(int argc, char ** argv){
// Update dynamic ngram cache with context ngram cache and save it to disk:
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.ngram_cache.lookup_cache_dynamic);
LOG("\n\n");
@@ -43,7 +43,7 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.speculative.mparams_dft.path.empty()) {
if (params.speculative.draft.mparams.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
// TODO: simplify this logic
{
const auto & params_spec = params.speculative;
const auto & params_spec = params.speculative.draft;
auto params_dft = params;
@@ -85,15 +85,15 @@ int main(int argc, char ** argv) {
params_dft.n_ctx = params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams_dft;
params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
if (params_spec.cpuparams.n_threads > 0) {
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params_dft.cpuparams.n_threads = params.speculative.draft.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params.speculative.draft.cpuparams_batch.n_threads;
}
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
params_dft.tensor_buft_overrides = params.speculative.draft.tensor_buft_overrides;
auto mparams_dft = common_model_params_to_llama(params_dft);
@@ -103,8 +103,8 @@ int main(int argc, char ** argv) {
return 1;
}
params.speculative.model_dft = model_dft.get();
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
params.speculative.draft.model = model_dft.get();
params.speculative.draft.cparams = common_context_params_to_llama(params_dft);
}
// Tokenize the prompt
@@ -187,16 +187,6 @@ int main(int argc, char ** argv) {
// generate a new draft
draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
if ((int) draft.size() > params_spec.n_max) {
LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max);
draft.resize(params_spec.n_max);
}
if ((int) draft.size() < params_spec.n_min) {
LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min);
draft.clear();
}
// save the original draft size
n_draft = draft.size();
@@ -220,19 +210,12 @@ int main(int argc, char ** argv) {
}
}
GGML_ASSERT(n_draft > 0);
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
// do not waste time on small drafts
if (draft.size() < (size_t) params_spec.n_min) {
draft.clear();
}
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
}
@@ -340,7 +323,7 @@ int main(int argc, char ** argv) {
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_INF("\n");
LOG_INF("n_draft = %d\n", params_spec.n_max);
LOG_INF("n_draft = %d\n", params_spec.draft.n_max);
LOG_INF("n_predict = %d\n", n_predict);
LOG_INF("n_drafted = %d\n", n_drafted);
LOG_INF("n_accept = %d\n", n_accept);
+10 -10
View File
@@ -49,7 +49,7 @@ int main(int argc, char ** argv) {
return 1;
}
if (params.speculative.mparams_dft.path.empty()) {
if (params.speculative.draft.mparams.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@@ -58,7 +58,7 @@ int main(int argc, char ** argv) {
const int n_seq_dft = params.n_parallel;
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_draft_split = params.speculative.p_split;
const float p_draft_split = params.speculative.draft.p_split;
std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
std::uniform_real_distribution<> u_dist;
@@ -80,15 +80,15 @@ int main(int argc, char ** argv) {
ctx_tgt = llama_init_tgt->context();
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.mparams_dft;
params.n_gpu_layers = params.speculative.n_gpu_layers;
if (params.speculative.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
params.devices = params.speculative.draft.devices;
params.model = params.speculative.draft.mparams;
params.n_gpu_layers = params.speculative.draft.n_gpu_layers;
if (params.speculative.draft.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.draft.cpuparams.n_threads;
}
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
params.cpuparams_batch.n_threads = params.speculative.draft.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.draft.tensor_buft_overrides;
auto llama_init_dft = common_init_from_params(params);
@@ -183,7 +183,7 @@ int main(int argc, char ** argv) {
//GGML_ASSERT(n_vocab == llama_vocab_n_tokens(model_dft));
// how many tokens to draft each time
int n_draft = params.speculative.n_max;
int n_draft = params.speculative.draft.n_max;
int n_predict = 0;
int n_drafted = 0;
+4 -5
View File
@@ -470,11 +470,10 @@ endforeach()
target_link_libraries(ggml-base PRIVATE Threads::Threads)
find_library(MATH_LIBRARY m)
if (MATH_LIBRARY)
if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT})
target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY})
endif()
if (DEFINED MATH_LIBRARY)
target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY})
elseif (NOT WIN32 AND NOT DEFINED ENV{ONEAPI_ROOT})
target_link_libraries(ggml-base PRIVATE m)
endif()
if (CMAKE_SYSTEM_NAME MATCHES "Android")
+18 -1
View File
@@ -1826,7 +1826,24 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
continue;
}
i = get_i_delayed(i);
const int i_delayed = get_i_delayed(i);
// If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices.
// A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has
// its compute flag disabled and thus gets its data zeroed out.
// If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled.
if (i_delayed > i) {
for (size_t j = 0; j < n_backends; j++) {
auto & bcj = backend_ctx->backend_configs[j];
if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
for (int ii = i + 1; ii <= i_delayed; ii++) {
bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE;
}
}
}
}
i = i_delayed;
for (size_t j = 0; j < n_backends; j++) {
auto & bcj = backend_ctx->backend_configs[j];
+12
View File
@@ -181,6 +181,12 @@ struct ggml_backend_registry {
return;
}
for (auto & entry : backends) {
if (entry.reg == reg) {
return;
}
}
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
__func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
@@ -192,6 +198,12 @@ struct ggml_backend_registry {
}
void register_device(ggml_backend_dev_t device) {
for (auto & dev : devices) {
if (dev == device) {
return;
}
}
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
#endif
+512 -248
View File
@@ -25,6 +25,7 @@
#include "ggml-impl.h"
#include "ggml.h"
#include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_add_rms_norm.h>
#include <aclnnop/aclnn_addcdiv.h>
@@ -45,7 +46,9 @@
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_ger.h>
#include <aclnnop/aclnn_group_norm.h>
#include <aclnnop/aclnn_gather_v2.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_scatter.h>
#include <aclnnop/aclnn_gt_scalar.h>
#include <aclnnop/aclnn_im2col.h>
#include <aclnnop/aclnn_index_copy.h>
@@ -62,6 +65,7 @@
#include <aclnnop/aclnn_permute.h>
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_pow_tensor_tensor.h>
#include <aclnnop/aclnn_recurrent_gated_delta_rule.h>
#include <aclnnop/aclnn_reduce_sum.h>
#include <aclnnop/aclnn_reflection_pad1d.h>
#include <aclnnop/aclnn_repeat.h>
@@ -69,11 +73,15 @@
#include <aclnnop/aclnn_rms_norm.h>
#include <aclnnop/aclnn_roll.h>
#include <aclnnop/aclnn_softmax.h>
#include <aclnnop/aclnn_softmax_cross_entropy_with_logits.h>
#include <aclnnop/aclnn_sub.h>
#include <aclnnop/aclnn_sum.h>
#include <aclnnop/aclnn_threshold.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triangular_solve.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_logical_not.h>
#include <aclnnop/aclnn_masked_fill_scalar.h>
#include <aclnnop/aclnn_upsample_nearest_2d.h>
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
#include <aclnnop/aclnn_zero.h>
@@ -151,6 +159,107 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst.get(), acl_src1.get());
}
// Fused SwiGLU using aclnnSwiGlu: splits input along innermost dim, applies
// SiLU to left half, multiplies by right half.
//
// Falls back to the generic two-kernel path when src[1] != nullptr (two
// independent halves) or swapped != 0 (reversed activation order), as
// aclnnSwiGlu only handles the single interleaved tensor in standard order.
//
// CANN tiling for SwiGlu requires (storageShapeDim + viewDims) to be even.
// aclCreateTensor always uses storageShapeDim=1, so viewDims must be odd.
// We use a 3D view (1+3=4, even) to satisfy this constraint while preserving
// correct split semantics along the innermost (ne[0]) dimension.
void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
auto silu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Silu, acl_src, acl_dst);
};
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
if (dst->src[1] != nullptr || swapped != 0) {
ggml_cann_op_unary_gated(silu_fn, ctx, dst);
return;
}
// aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise.
if (dst->src[0]->ne[0] % 2 != 0) {
ggml_cann_op_unary_gated(silu_fn, ctx, dst);
return;
}
ggml_tensor * src0 = dst->src[0];
size_t elem_size = ggml_element_size(src0);
// src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3]
// CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last).
int64_t ne0_x2 = src0->ne[0];
int64_t ne1 = src0->ne[1];
int64_t ne23 = src0->ne[2] * src0->ne[3];
int64_t src3d_ne[] = { ne0_x2, ne1, ne23 };
size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] };
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
elem_size, src3d_ne, src3d_nb, 3);
// dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3]
int64_t ne0 = dst->ne[0];
int64_t dst3d_ne[] = { ne0, ne1, ne23 };
size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] };
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
elem_size, dst3d_ne, dst3d_nb, 3);
// CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0.
GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get());
}
// Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim),
// activates the LEFT half with GELU, multiplies by right half.
// approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention.
// outGelu is a required-but-discard output buffer.
//
// Falls back to the generic two-kernel path when src[1] != nullptr (two
// independent halves) or swapped != 0 (reversed activation order), as
// aclnnGeGluV3 only handles the single interleaved tensor in standard order.
void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) {
auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst);
};
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
if (dst->src[1] != nullptr || swapped != 0) {
ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
return;
}
// aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise.
if (dst->src[0]->ne[0] % 2 != 0) {
ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
return;
}
ggml_tensor * src0 = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
// Allocate a temporary buffer for the required outGelu output (same shape as dst).
// Build contiguous strides since the pool allocation is a fresh buffer.
size_t elem_size = ggml_element_size(dst);
int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] };
size_t nb[GGML_MAX_DIMS];
nb[0] = elem_size;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
nb[i] = nb[i - 1] * ne[i - 1];
}
size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1];
ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size);
acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor(
gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS);
// V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention.
// GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor).
GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true,
acl_dst.get(), acl_gelu_out.get());
}
/**
* @brief Repeats elements of a tensor along each dimension according to the
* specified repeat array.
@@ -445,28 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes);
void * buffer = temp_buffer_allocator.get();
int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
size_t div_nb[GGML_MAX_DIMS];
div_nb[0] = sizeof(float);
int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
size_t norm_nb[GGML_MAX_DIMS];
norm_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
div_nb[i] = div_nb[i - 1] * div_ne[i - 1];
norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1];
}
acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS);
acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
std::vector<int64_t> norm_dims = { 3 };
acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size());
float p_value = 2.0f;
acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get());
GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get());
// Clamp norm to at least eps: scale = 1/fmaxf(norm, eps)
acl_scalar_ptr acl_min = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
float flt_max = FLT_MAX;
acl_scalar_ptr acl_max = ggml_cann_create_scalar(&flt_max, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_div.get(), acl_min.get(), acl_max.get(), acl_div.get());
ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool());
acl_tensor_ptr acl_clamped;
GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get());
if (eps > 0.0f) {
void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes);
acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get());
}
aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get();
GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get());
}
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -482,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
logits_nb[1] = logits_nb[0] * logits_ne[0];
acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
size_t log_softmax_type_size = sizeof(float);
int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
void * log_softmax_buffer = log_softmax_allocator.get();
int64_t log_softmax_ne[] = { nc, nr };
size_t log_softmax_nb[2];
log_softmax_nb[0] = log_softmax_type_size;
log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size,
log_softmax_ne, log_softmax_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get());
int64_t labels_ne[] = { nc, nr };
size_t labels_nb[2];
labels_nb[0] = ggml_type_size(src1->type);
labels_nb[1] = labels_nb[0] * labels_ne[0];
acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
size_t mul_type_size = sizeof(float);
int64_t mul_n_bytes = nr * nc * mul_type_size;
ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
void * mul_buffer = mul_allocator.get();
size_t loss_per_sample_type_size = sizeof(float);
int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size;
ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes);
void * loss_per_sample_buffer = loss_per_sample_allocator.get();
int64_t mul_ne[] = { nc, nr };
size_t mul_nb[2];
mul_nb[0] = mul_type_size;
mul_nb[1] = mul_nb[0] * mul_ne[0];
acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
int64_t loss_per_sample_ne[] = { nr };
size_t loss_per_sample_nb[1];
loss_per_sample_nb[0] = loss_per_sample_type_size;
acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor(
loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1);
GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get());
size_t backprop_n_bytes = nr * nc * sizeof(float);
ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes);
void * backprop_buffer = backprop_allocator.get();
acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
size_t sum_per_sample_type_size = sizeof(float);
int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
void * sum_per_sample_buffer = sum_per_sample_allocator.get();
int64_t sum_per_sample_ne[] = { nr };
size_t sum_per_sample_nb[1];
sum_per_sample_nb[0] = sum_per_sample_type_size;
acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor(
sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
std::vector<int64_t> sum_dims = { 1 };
acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size());
bool keep_dims = false;
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT,
acl_sum_per_sample.get());
GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(),
acl_loss_per_sample.get(), acl_backprop.get());
size_t total_sum_type_size = sizeof(float);
int64_t total_sum_n_bytes = 1 * total_sum_type_size;
@@ -547,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
std::vector<int64_t> total_sum_dims = { 0 };
acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size());
bool keep_dims = false;
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
acl_total_sum.get());
float value = -1.0f / static_cast<float>(nr);
float value = 1.0f / static_cast<float>(nr);
acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
acl_tensor_ptr acl_dst =
ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
@@ -589,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
acl_mean_out.get(), acl_rstd_out.get());
}
void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
size_t offset = ((int32_t *) dst->op_params)[3];
bool inplace = (bool) ((int32_t *) dst->op_params)[4];
size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 };
// Create a view of dst at the target offset with src1's dimensions
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1);
if (!inplace) {
// First copy src0 to dst entirely
size_t cpy_size = ggml_nbytes(dst);
ACL_CHECK(
aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
}
// Copy src1 into the target region of dst
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get());
}
void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
@@ -652,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
aclnn_reduce_sum(ctx, dst, reduce_dims, 4);
}
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
// GGML cumsum operates along dim 0 (innermost / ne[0]).
// ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0],
// so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor).
GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3,
ggml_cann_type_mapping(dst->type), acl_dst.get());
}
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular
ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3]
acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1);
acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst);
// mOut: triangular copy of A (required output), same shape as A.
const size_t a_bytes = ggml_nbytes(src0);
ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes);
acl_tensor_ptr acl_m = ggml_cann_create_tensor(
m_alloc.get(), ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
// Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false.
GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve,
acl_b.get(), acl_a.get(), false, false, false,
acl_x.get(), acl_m.get());
}
void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
GGML_ASSERT(src->ne[1] == 1);
const int64_t N = src->ne[0];
const int64_t n_batch = src->ne[2] * src->ne[3];
const size_t nb_f32 = sizeof(float);
// Fill dst with zeros.
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
{
float zero = 0.0f;
acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get());
}
// Copy src vector onto the diagonal of dst via strided views.
// src viewed as [N, n_batch], contiguous strides.
int64_t ne_vec[2] = { N, n_batch };
size_t nb_src_vec[2] = { nb_f32, N * nb_f32 };
// dst diagonal view: stride (N+1)*4 steps along the diagonal.
size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 };
acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2);
acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get());
}
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
float c = ggml_get_op_params_f32(dst, 0);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get());
}
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
const int64_t S = src->ne[0];
const int64_t n_batch = src->ne[2] * src->ne[3];
const size_t nb_f32 = sizeof(float);
int64_t ne3d[3] = { S, S, n_batch };
size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 };
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
switch (ttype) {
case GGML_TRI_TYPE_LOWER:
// Tril(-1): preserve row > col (strict lower), zero upper + diagonal.
GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get());
break;
case GGML_TRI_TYPE_UPPER_DIAG:
// Triu(0): preserve row <= col (upper + diagonal), zero strict lower.
GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get());
break;
case GGML_TRI_TYPE_UPPER:
// Triu(1): preserve row < col (strict upper), zero lower + diagonal.
GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get());
break;
case GGML_TRI_TYPE_LOWER_DIAG:
// Tril(0): preserve row >= col (lower + diagonal), zero strict upper.
GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get());
break;
default:
GGML_ABORT("unsupported tri type");
}
}
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
@@ -1695,152 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get());
}
/**
* @brief Performs index select operation on a 4D tensor using the CANN backend.
*
* This function applies the `IndexSelect` operation along a specific dimension
* of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
* It iterates over the last two dimensions of the source tensor, creates the corresponding
* CANN tensors for the source, index, and output slices, and executes the `IndexSelect`
* operation for each slice.
*
* @param ctx The context for CANN backend operations.
* @param src_buffer The source buffer containing the 4D input tensor data.
* @param src_ne The dimensions of the source tensor.
* @param src_nb The strides (byte offsets) of the source tensor.
* @param dst_buffer The destination buffer where the output tensor data will be written.
* @param dst_ne The dimensions of the destination tensor.
* @param dst_nb The strides (byte offsets) of the destination tensor.
* @param index The index tensor specifying the indices to select from the source tensor.
* @param type The data type of the source and destination tensors.
*/
static void aclnn_index_select_4d(ggml_backend_cann_context & ctx,
void * src_buffer,
int64_t * src_ne,
size_t * src_nb,
void * dst_buffer,
int64_t * dst_ne,
size_t * dst_nb,
ggml_tensor * index,
ggml_type type) {
for (int64_t i = 0; i < src_ne[3]; i++) {
for (int64_t j = 0; j < src_ne[2]; j++) {
// src
acl_tensor_ptr acl_src_tensor =
ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
// index
acl_tensor_ptr acl_index = ggml_cann_create_tensor(
(char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
// out
acl_tensor_ptr acl_out =
ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get());
}
}
}
/**
* @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.
*
* This function applies the `IndexCopy` operation along a specific dimension of the
* destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)
* to positions specified by the index tensor (`index`).
* It iterates over the last two dimensions of the tensors, creates the corresponding
* CANN tensors for source, index, and destination slices, and performs the index copy
* operation for each slice.
*
* @param ctx The context for CANN backend operations.
* @param src_buffer The source buffer containing the 4D input tensor data to be copied.
* @param src_ne The dimensions of the source tensor.
* @param src_nb The strides (byte offsets) of the source tensor.
* @param dst_buffer The destination buffer where values will be copied to.
* @param dst_ne The dimensions of the destination tensor.
* @param dst_nb The strides (byte offsets) of the destination tensor.
* @param index The index tensor specifying target positions in the destination tensor.
* @param type The data type of the source and destination tensors.
*/
static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx,
void * src_buffer,
int64_t * src_ne,
size_t * src_nb,
void * dst_buffer,
int64_t * dst_ne,
size_t * dst_nb,
ggml_tensor * index,
ggml_type type) {
for (int64_t i = 0; i < src_ne[3]; i++) {
for (int64_t j = 0; j < src_ne[2]; j++) {
// src
acl_tensor_ptr acl_src_tensor =
ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
// index
acl_tensor_ptr acl_index = ggml_cann_create_tensor(
(char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
// out
acl_tensor_ptr acl_out =
ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get());
}
}
}
void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // src
ggml_tensor * src0 = dst->src[0]; // weight
ggml_tensor * src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
|| dst->type == GGML_TYPE_BF16);
// n_idx: number of row indices per (i2, i3) batch slice.
// ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1.
const int64_t n_idx = src1->ne[0];
// Gather all (i2, i3) batch slices from src into dst.
// ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0].
// GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis).
// nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
// nb[2..3] for computing per-batch-slice base pointer offsets).
auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
const size_t * nb) {
int64_t src_ne[2] = { src0->ne[0], src0->ne[1] };
size_t src_nb_2d[2] = { nb[0], nb[1] };
int64_t dst_ne[2] = { src0->ne[0], n_idx };
size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] };
int64_t idx_ne[1] = { n_idx };
size_t idx_nb[1] = { (size_t)ggml_element_size(src1) };
for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) {
acl_tensor_ptr acl_src = ggml_cann_create_tensor(
(char *)src_base + i3 * nb[3] + i2 * nb[2],
acl_type, type_size, src_ne, src_nb_2d, 2);
acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
(char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1],
ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
idx_ne, idx_nb, 1);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
(char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
acl_type, type_size, dst_ne, dst_nb_2d, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get());
}
}
};
switch (src0->type) {
case GGML_TYPE_BF16:
case GGML_TYPE_F16:
case GGML_TYPE_F32:
if (src0->type == dst->type) {
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1,
dst->type);
gather_batched(src0->data,
ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
src0->nb);
} else {
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
void * src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = dst->nb[0];
// Cast src0 to dst type, then gather.
ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
ggml_nelements(src0) * ggml_element_size(dst));
size_t src_cast_nb[GGML_MAX_DIMS];
src_cast_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
}
acl_tensor_ptr src_trans_tensor =
ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
dst->type);
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src0->ne, src_cast_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
gather_batched(src_cast_allocator.get(),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src_cast_nb);
}
break;
case GGML_TYPE_Q8_0:
{
// add 1 dim for bcast mul.
// Dequantize Q8_0 to dst type, then gather.
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1];
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne;
int64_t scale_offset = 0;
// [3,4,5,64] -> [3,4,5,2,32]
weight_ne[0] = QK8_0;
weight_ne[1] = src0->ne[0] / QK8_0;
weight_nb[0] = sizeof(int8_t);
weight_nb[1] = weight_nb[0] * weight_ne[0];
weight_ne[0] = QK8_0;
weight_ne[1] = src0->ne[0] / QK8_0;
weight_nb[0] = sizeof(int8_t);
weight_nb[1] = weight_nb[0] * weight_ne[0];
for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {
weight_ne[i] = src0->ne[i - 1];
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
}
// [3,4,5,64] -> [3,4,5,2,1]
scale_ne[0] = 1;
scale_ne[1] = src0->ne[0] / QK8_0;
scale_nb[0] = sizeof(uint16_t);
@@ -1849,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
scale_ne[i] = src0->ne[i - 1];
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
}
// [3,4,5,64] -> [3,4,5,2,32]
dequant_ne = weight_ne;
dequant_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
}
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(),
ggml_nelements(src0) * ggml_type_size(dst->type));
acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
weight_ne, weight_nb, GGML_MAX_DIMS + 1);
acl_tensor_ptr acl_scale_tensor =
ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
acl_tensor_ptr dequant_tensor =
ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get());
dequant_nb[0] = ggml_type_size(dst->type);
const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t);
ggml_cann_pool_alloc dequant_allocator(ctx.pool(),
ggml_nelements(src0) * ggml_type_size(dst->type));
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
weight_ne, weight_nb, GGML_MAX_DIMS + 1);
acl_tensor_ptr acl_scale = ggml_cann_create_tensor(
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
acl_tensor_ptr acl_dequant = ggml_cann_create_tensor(
dequant_allocator.get(), ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get());
// Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides.
dequant_ne = src0->ne;
dequant_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
}
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne,
dst->nb, src1, dst->type);
gather_batched(dequant_allocator.get(),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
dequant_nb);
break;
}
default:
@@ -1883,31 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
}
void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // src
ggml_tensor * src1 = dst->src[1]; // index
ggml_tensor * src0 = dst->src[0]; // source values
ggml_tensor * src1 = dst->src[1]; // row indices
// n_idx: number of source rows to scatter per batch slice.
// ggml guarantees: src0->ne[1] == src1->ne[0].
const int64_t n_idx = src1->ne[0];
// Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index.
// ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst.
// InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis).
// src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
// nb[2..3] for computing per-batch-slice base pointer offsets).
auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
const size_t * src_nb) {
int64_t d_ne[2] = { dst->ne[0], dst->ne[1] };
size_t d_nb[2] = { dst->nb[0], dst->nb[1] };
int64_t s_ne[2] = { dst->ne[0], n_idx };
size_t s_nb_2d[2] = { src_nb[0], src_nb[1] };
int64_t i_ne[1] = { n_idx };
size_t i_nb[1] = { (size_t)ggml_element_size(src1) };
for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) {
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
(char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
acl_type, type_size, d_ne, d_nb, 2);
acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
(char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1],
ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
i_ne, i_nb, 1);
acl_tensor_ptr acl_src = ggml_cann_create_tensor(
(char *)src_base + i3 * src_nb[3] + i2 * src_nb[2],
acl_type, type_size, s_ne, s_nb_2d, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get());
}
}
};
switch (dst->type) {
case GGML_TYPE_F32:
{
aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type);
break;
}
scatter_batched(src0->data,
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src0->nb);
break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
{
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
void * src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = sizeof(uint16_t);
// Cast src0 (F32) to dst type first.
ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
ggml_nelements(src0) * ggml_type_size(dst->type));
size_t src_cast_nb[GGML_MAX_DIMS];
src_cast_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
}
acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
dst->type);
acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src0->ne, src_cast_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
scatter_batched(src_cast_allocator.get(),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src_cast_nb);
break;
}
default:
@@ -3268,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst
int64_t paddingsArray[2] = { opts[0], opts[1] };
acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2);
for (int64_t i = 0; i < src0->ne[3]; i++) {
acl_tensor_ptr acl_src =
ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type),
ggml_element_size(src0), src0->ne, src0->nb, 3);
// Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3
// is contiguous with respect to dim2 in both src and dst.
GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]);
GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]);
acl_tensor_ptr acl_dst =
ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type),
ggml_element_size(dst), dst->ne, dst->nb, 3);
int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] };
int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] };
GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
}
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
ggml_element_size(src0), src_ne_3d, src0->nb, 3);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
ggml_element_size(dst), dst_ne_3d, dst->nb, 3);
GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
}
void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
// Write element-wise equality (0 or 1) into a temporary buffer to avoid
// modifying src0 in-place. Use the same type as src0 so ReduceSum can
// consume it directly without a type cast.
ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0));
size_t eq_nb[GGML_MAX_DIMS];
eq_nb[0] = ggml_element_size(src0);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1];
}
acl_tensor_ptr acl_eq = ggml_cann_create_tensor(
eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
src0->ne, eq_nb, GGML_MAX_DIMS);
acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1);
GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get());
ggml_cann_sum(ctx, dst);
// Sum the 0/1 values into dst.
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
int64_t dims[4] = { 0, 1, 2, 3 };
acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4);
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true,
ggml_cann_type_mapping(dst->type), acl_dst.get());
}
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3306,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get());
}
void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
float beta_val = 1.0f;
float threshold_val = 20.0f;
acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT);
acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get());
}
void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
};
ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst);
}
/**
* @brief Performs expert-specific matrix multiplication (MoE) with
* floating-point precision using the CANN backend.
@@ -3892,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
}
static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // weight
ggml_tensor * src1 = dst->src[1]; // input
ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03]
ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13]
GGML_TENSOR_BINARY_OP_LOCALS
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
// dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T.
//
// ggml_cann_create_tensor reverses dimension order, so ACL sees:
// acl_src0 slice: ggml[m,K] -> ACL[K,m]
// acl_src1 slice: ggml[n,K] -> ACL[K,n]
// acl_dst slice: ggml[m,n] -> ACL[n,m]
//
// Build a transposed view of src1 by swapping ne[0]/ne[1]:
// src1_t: ggml[K,n] (swapped strides) -> ACL[n,K]
//
// Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓
//
// The outer batch loop is kept because src0 may have fewer batch slices than
// dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported
// by standard CANN Matmul broadcasting.
const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type);
const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type);
const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type);
const size_t src0_type_sz = ggml_type_size(src0->type);
const size_t src1_type_sz = ggml_type_size(src1->type);
const size_t dst_type_sz = ggml_type_size(dst->type);
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t i02 = i2 / dps2;
const int64_t i03 = i3 / dps3;
const int64_t i12 = i2;
const int64_t i13 = i3;
acl_tensor_ptr accumulator =
ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
// src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m]
int64_t src0_ne[2] = { ne00, ne01 };
size_t src0_nb[2] = { nb00, nb01 };
acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor(
(char *) src0->data + i02 * nb02 + i03 * nb03,
src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2);
// The outer product needs to be accumulated in this dimension.
for (int64_t i1 = 0; i1 < ne11; i1++) {
acl_tensor_ptr acl_input = ggml_cann_create_tensor(
(char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src1->ne, src1->nb, 1);
// src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K]
int64_t src1_t_ne[2] = { ne11, ne10 };
size_t src1_t_nb[2] = { nb11, nb10 };
acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor(
(char *) src1->data + i2 * nb12 + i3 * nb13,
src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2);
acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
(char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src0->nb, 1);
// dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m]
int64_t dst_ne[2] = { ne0, ne1 };
size_t dst_nb[2] = { nb0, nb1 };
acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor(
(char *) dst->data + i2 * nb2 + i3 * nb3,
dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2);
ggml_cann_pool_alloc output_allocator(ctx.pool());
void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst->ne, dst->nb, 2);
GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
float alpha_value = 1.0f;
aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
}
// Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓
GGML_CANN_CALL_ACLNN_OP(ctx, Matmul,
acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1);
}
}
}
@@ -4170,3 +4433,4 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor *
}
}
}
+56
View File
@@ -32,6 +32,9 @@
#include <aclnnop/aclnn_cat.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_cumsum.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_gelu.h>
#include <aclnnop/aclnn_gelu_v2.h>
@@ -47,6 +50,9 @@
#include <aclnnop/aclnn_sign.h>
#include <aclnnop/aclnn_silu.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_softplus.h>
#include <aclnnop/aclnn_swi_glu.h>
#include <aclnnop/aclnn_geglu.h>
#include <aclnnop/aclnn_slice.h>
#include <aclnnop/aclnn_sqrt.h>
#include <aclnnop/aclnn_tanh.h>
@@ -69,6 +75,9 @@
*/
void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate);
/**
* @brief Applies the Leaky ReLU activation function to a tensor using the CANN
* backend.
@@ -325,6 +334,48 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the cumulative sum of a ggml tensor along dim 0 using the
* CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`.
*/
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes a triangular mask (tril/triu) of a square ggml tensor
* using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_TRI`.
*/
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Solves a triangular linear system AX=B using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`.
*/
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Creates a diagonal matrix from a vector using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_DIAG`.
*/
void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Fills a tensor with a constant scalar value using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_FILL`.
*/
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Upsamples a ggml tensor using nearest neighbor interpolation using
* the CANN backend.
@@ -461,6 +512,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor *
// @see ggml_cann_dup.
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst);
// @see ggml_cann_acc, but copies src1 into dst instead of adding.
void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the softmax activation with optional masking.
*
@@ -813,6 +867,8 @@ void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst);
* dst->op is expected to be `GGML_OP_STEP`.
*/
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Performs the Flash Attention extended operator using the CANN backend.
+56 -10
View File
@@ -1428,6 +1428,22 @@ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
return false;
}
/**
* @brief Set a region of a tensor's device memory to a specified value.
*
* @param buffer The CANN buffer containing the tensor.
* @param tensor Pointer to the tensor whose memory will be set.
* @param value The value to which each byte in the region will be set.
* @param offset Byte offset within the tensor's data to start setting.
* @param size Number of bytes to set.
*/
static void ggml_backend_cann_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
ggml_cann_set_device(ctx->device);
ACL_CHECK(aclrtMemset((char *) tensor->data + offset, size, value, size));
}
/**
* @brief Clear a CANN buffer by setting all its memory to a specified value.
*
@@ -1454,7 +1470,7 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
/* .get_base = */ ggml_backend_cann_buffer_get_base,
/* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
/* .memset_tensor = */ NULL,
/* .memset_tensor = */ ggml_backend_cann_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
/* .set_tensor_2d = */ NULL,
@@ -1835,6 +1851,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_UNARY_OP_STEP:
ggml_cann_step(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_cann_softplus(ctx, dst);
break;
default:
return false;
}
@@ -1845,20 +1864,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
GGML_CANN_CALL_OP_UNARY_GATED(Relu);
break;
case GGML_GLU_OP_GEGLU:
ggml_cann_geglu(ctx, dst, 0); // approximate=0 → tanh
break;
case GGML_GLU_OP_GEGLU_ERF:
// aclnnGelu internally uses the erf-based approximation.
GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
ggml_cann_geglu(ctx, dst, 1); // approximate=1 → erf
break;
case GGML_GLU_OP_SWIGLU:
GGML_CANN_CALL_OP_UNARY_GATED(Silu);
ggml_cann_swiglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
{
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
};
ggml_cann_op_unary_gated(lambda, ctx, dst);
}
ggml_cann_geglu_quick(ctx, dst);
break;
default:
return false;
@@ -1920,6 +1935,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_CPY:
ggml_cann_cpy(ctx, dst);
break;
case GGML_OP_SET:
ggml_cann_set(ctx, dst);
break;
case GGML_OP_CONT:
ggml_cann_dup(ctx, dst);
break;
@@ -1989,6 +2007,21 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_SSM_CONV:
ggml_cann_ssm_conv(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cann_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cann_tri(ctx, dst);
break;
case GGML_OP_FILL:
ggml_cann_fill(ctx, dst);
break;
case GGML_OP_DIAG:
ggml_cann_diag(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cann_solve_tri(ctx, dst);
break;
default:
return false;
}
@@ -2324,6 +2357,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
if (use_cann_graph) {
// If no matching graph is found, the graph needs to be recaptured.
graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
if (graph_capture_required) {
// If no matching graph is found, add a new ACL graph.
ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
@@ -2382,6 +2416,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_SOFTPLUS:
return true;
default:
return false;
@@ -2572,6 +2607,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
case GGML_OP_SET:
case GGML_OP_GROUP_NORM:
return true;
case GGML_OP_PAD:
@@ -2649,6 +2685,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
}
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CUMSUM:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_TRI:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_FILL:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_DIAG:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOLVE_TRI:
return op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}
+16 -16
View File
@@ -2005,12 +2005,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
const int lda = KB * sizeof(TA);
//const int ldb = KB * sizeof(TB);
static thread_local packed_B_t Tile0[TILE_N * TILE_K];
static thread_local packed_B_t Tile1[TILE_N * TILE_K];
static thread_local int8_t Tile23[TILE_M * TILE_K];
alignas(64) static thread_local packed_B_t Tile0[TILE_N * TILE_K];
alignas(64) static thread_local packed_B_t Tile1[TILE_N * TILE_K];
alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K];
static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
alignas(64) static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
alignas(64) static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
// double buffering C to interleave avx512 and amx
int32_t * C_cur = TileC0;
@@ -2187,21 +2187,21 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
const int m1 = std::max(M - TILE_M, 0);
//const int lda = KB * sizeof(TA);
static thread_local int8_t Tile0[TILE_N * TILE_K];
static thread_local int8_t Tile1[TILE_N * TILE_K];
static thread_local int8_t Tile23[TILE_M * TILE_K];
alignas(64) static thread_local int8_t Tile0[TILE_N * TILE_K];
alignas(64) static thread_local int8_t Tile1[TILE_N * TILE_K];
alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K];
// mat mul result for each group
static thread_local int32_t Tile4[TILE_M * TILE_N];
static thread_local int32_t Tile5[TILE_M * TILE_N];
static thread_local int32_t Tile6[TILE_M * TILE_N];
static thread_local int32_t Tile7[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Tile4[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Tile5[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Tile6[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Tile7[TILE_M * TILE_N];
// sum of each QK_K block, contains 8 groups, int32
static thread_local int32_t Sumi4[TILE_M * TILE_N];
static thread_local int32_t Sumi5[TILE_M * TILE_N];
static thread_local int32_t Sumi6[TILE_M * TILE_N];
static thread_local int32_t Sumi7[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Sumi4[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Sumi5[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Sumi6[TILE_M * TILE_N];
alignas(64) static thread_local int32_t Sumi7[TILE_M * TILE_N];
const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
for (int i = 0; i < KB; ++i) {
+65
View File
@@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntb() * 8 == 256) {
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
svuint32_t idx = svld1(svptrue_b32(), idx_arr);
static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0};
svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1);
static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3};
svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2);
for (int y = 0; y < nr; y += 4) {
const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
for (int x = 0; x < nc; x += ncols_interleaved) {
const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
const block_q8_0x4 * a_ptr = a_ptr_base;
svfloat32_t acc_f32_01 = svdup_f32(0);
svfloat32_t acc_f32_23 = svdup_f32(0);
for (int b = 0; b < nb; b++) {
svint32_t acc_01 = svdup_s32(0);
svint32_t acc_23 = svdup_s32(0);
// Process 4 chunks of 8 positions each
for (int chunk = 0; chunk < 4; chunk++) {
svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32);
svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16);
svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32);
acc_01 = svmmla_s32(acc_01, s_a01, s_b0123);
acc_23 = svmmla_s32(acc_23, s_a23, s_b0123);
}
// Reorder outputs from 2×2 tiles to row-major
// acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3]
// acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3]
svint32_t row01 = svtbl_s32(acc_01, idx);
svint32_t row23 = svtbl_s32(acc_23, idx);
svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d);
svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d);
svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1);
svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2);
acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0));
acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2));
a_ptr++;
b_ptr++;
}
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01);
svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4));
svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23);
svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4));
}
}
return;
}
#endif // SVE compile-time end
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
+12
View File
@@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
}
static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) {
#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only
if (!(x > 0.0f)) {
return 0;
}
const __nv_fp8_e4m3 xf(x);
return xf.__x;
#else
NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
const uint8_t sign_bit = (x < 0.0f) << 3;
float ax = fabsf(x) * e;
+14 -1
View File
@@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
@@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
@@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
@@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
// For GLM 4.7 Flash
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+4
View File
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 320: {
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst);
} break;
case 512: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
+28 -9
View File
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
@@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
@@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm
}
}
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
launch_fattn<DV, cols_per_block/ncols2, ncols2>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
return;
if constexpr (ncols2 <= 16) {
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
launch_fattn<DV, cols_per_block/ncols2, ncols2>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
return;
}
}
if constexpr (ncols2 <= 8) {
@@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
if constexpr (DKQ == 320) { // Mistral Small 4
if (use_gqa_opt && gqa_ratio % 32 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst);
return;
}
GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32");
}
if constexpr (DKQ == 576) {
if (use_gqa_opt && gqa_ratio % 16 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
@@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
}
}
if constexpr (DKQ <= 512) {
if constexpr (DKQ <= 512 && DKQ != 320) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
@@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(320, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
extern DECL_FATTN_TILE_CASE(576, 512);
+24
View File
@@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 320:
// For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build).
GGML_ASSERT(V->ne[0] == 256);
{
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
const bool use_gqa_opt = mask && max_bias == 0.0f;
GGML_ASSERT(use_gqa_opt);
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 32 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst);
}
break;
case 512:
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
@@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
}
break;
case 320:
if (V->ne[0] != 256 || !gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
if (gqa_ratio % 32 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 512:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
+22 -12
View File
@@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma {
#endif // AMD_MFMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
uint32_t a_scale,
uint32_t b_scale) {
template <ggml_type type>
static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
uint32_t a_scale,
uint32_t b_scale) {
#ifdef BLACKWELL_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
float * Dxi = (float *) D.x;
asm volatile(
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
if constexpr (type == GGML_TYPE_MXFP4) {
asm volatile(
"mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
} else {
asm volatile(
"mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
"%10, {0, 0}, %11, {0, 0};"
: "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
}
#else
GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
#endif // BLACKWELL_MMA_AVAILABLE
#endif // BLACKWELL_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
+10 -11
View File
@@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q(
|| GGML_CUDA_CC_IS_CDNA(cc);
// TODO: tighter pool buffer size vs q8 path
const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4);
if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
if (use_native_fp4) {
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
ne11, ne12, ne13, stream);
} else {
@@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q(
}
// Stride depends on quantization format
const int64_t s12 = use_native_mxfp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) /
(8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
:
const int64_t s12 = use_native_fp4 ?
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
@@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q(
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
if (use_native_mxfp4) {
quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
if (use_native_fp4) {
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
} else {
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
@@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q(
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4");
const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) :
ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
const int64_t s13 = ne12*s12;
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
+147 -83
View File
@@ -10,9 +10,9 @@
using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_ITER_K_MXFP4_FP4 512
#define MMQ_NWARPS 8
#define MMQ_ITER_K 256
#define MMQ_ITER_K_FP4 512
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
};
// this struct is used for fp4 data types (currently only used for Blackwell)
// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
struct block_fp4_mmq {
uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
uint32_t d4[4];
int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
@@ -143,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
#if defined(BLACKWELL_MMA_AVAILABLE)
return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
#else
return MMQ_ITER_K;
if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
return MMQ_ITER_K_FP4;
}
#endif // defined(BLACKWELL_MMA_AVAILABLE)
return MMQ_ITER_K;
}
static constexpr __device__ int get_mmq_y_device() {
@@ -213,8 +217,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -240,7 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
#if defined(BLACKWELL_MMA_AVAILABLE)
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
#else
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -934,6 +942,128 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
}
}
#ifdef BLACKWELL_MMA_AVAILABLE
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kbx0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
constexpr int rows_per_warp = warp_size / threads_per_row;
uint32_t * x_u32 = (uint32_t *) x_tile;
const int txi = threadIdx.x;
const int kbx = txi % threads_per_row;
const int row_in_warp = txi / threads_per_row;
const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
uint32_t * x_u32_scale = x_u32 + 64 + kbx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if constexpr (need_check) {
i = min(i, i_max);
}
const block_nvfp4 * bxi = bxi_base + i * stride;
const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
const int q_base = row_base + 8 * kbx;
const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
#pragma unroll
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
}
x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
}
}
// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
// and the per-type stride constant differ.
template <int mmq_x, int mmq_y, ggml_type type>
static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
"vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I;
constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;
// 2 threads per quad supply the packed scale register to the block_scale MMA,
// see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
const int tidx_B = threadIdx.x / 4;
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
tile_A A[ntx][nfrags];
uint32_t scaleA[ntx][nfrags];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int frag = 0; frag < nfrags; ++frag) {
const int k0 = k00 + frag * tile_A::J;
load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
tile_B B[nfrags];
uint32_t scaleB[nfrags];
#pragma unroll
for (int frag = 0; frag < nfrags; ++frag) {
const int k0 = frag * tile_B::J;
load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int frag = 0; frag < nfrags; ++frag) {
tile_C C = {};
mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
#endif // BLACKWELL_MMA_AVAILABLE
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
@@ -1163,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
const int * __restrict__ y,
float * __restrict__ sum,
const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile<8, 8, int> tile_B;
typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
// Match layout from load_tiles_mxfp4_fp4
const int * x_qs = (const int *) x;
const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
const int * y_qs = (const int *) y + 4;
const uint32_t * y_sc = (const uint32_t *) y;
// tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
// Block scale
// Each thread has to point to a 4 byte scale value
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
const int k0 = k00 + k01;
load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
MMQ_MMA_TILE_X_K_FP4);
// based on block-scaling document, 2 threads in each quad need to supply to the scale value
const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
scaleA[n][k01 / (2 * QI_MXFP4)] =
*(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
tile_B B;
uint32_t scaleB; // 2xN scales
load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
tile_C C;
mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
}
}
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
@@ -3259,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3270,8 +3329,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
#ifdef BLACKWELL_MMA_AVAILABLE
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
#else
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
#endif // BLACKWELL_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
@@ -3406,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
#if defined(BLACKWELL_MMA_AVAILABLE)
// FP4 tile stores 8 blocks
constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
#else
constexpr int ne_block = 4 * QK8_1;
#endif // defined(BLACKWELL_MMA_AVAILABLE)
+3
View File
@@ -115,6 +115,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(gg
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 4;
case GGML_TYPE_NVFP4: return 4;
case GGML_TYPE_Q2_K: return 4;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 6;
@@ -135,6 +136,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggm
case GGML_TYPE_IQ3_S: return 6;
case GGML_TYPE_IQ3_XXS: return 7;
case GGML_TYPE_MXFP4: return 7;
case GGML_TYPE_NVFP4: return 8;
case GGML_TYPE_Q2_K: return 7;
case GGML_TYPE_Q3_K: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
@@ -221,6 +223,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
case GGML_TYPE_IQ4_NL: return 7;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 5;
case GGML_TYPE_NVFP4: return 5;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 7;
case GGML_TYPE_Q4_1: return 7;
+121 -21
View File
@@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
return static_cast<uint8_t>(biased);
}
static __global__ void quantize_mmq_nvfp4(
const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
#if defined(BLACKWELL_MMA_AVAILABLE)
const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB;
if (i0_base >= ne0) {
return;
}
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i01 = ids ? ids[i1] : i1;
const int64_t k_block = i0_base / QK_K;
const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K;
if (k_block >= blocks_per_col) {
return;
}
const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x;
block_fp4_mmq * y = (block_fp4_mmq *) vy;
block_fp4_mmq * yb = y + ib;
const int sub = (i0_base % QK_K) / QK_NVFP4_SUB;
float vals_raw[QK_NVFP4_SUB];
float amax_raw = 0.0f;
const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01;
#pragma unroll
for (int k = 0; k < QK_NVFP4_SUB; k++) {
const int64_t i00 = i0_base + k;
if (i00 < ne00) {
const float v = x[base_idx + i00];
vals_raw[k] = v;
amax_raw = fmaxf(amax_raw, fabsf(v));
} else {
vals_raw[k] = 0.0f;
}
}
static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2};
const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f);
float best_err = FLT_MAX;
uint8_t fp8_code = 0;
float subblock_scale = 0.0f;
#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell.
for (int i = 0; i < 5; i++) {
const int test_code = first_fp8_code + test_offsets[i];
if (test_code < 0 || test_code > 0x7e) {
continue;
}
const uint8_t code = (uint8_t) test_code;
const float test_scale = ggml_cuda_ue4m3_to_fp32(code);
const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f;
float cur_err = 0.0f;
#pragma unroll
for (int k = 0; k < QK_NVFP4_SUB; ++k) {
const float v = vals_raw[k];
const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale);
const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale;
cur_err = fmaf(err_diff, err_diff, cur_err);
}
if (cur_err < best_err) {
best_err = cur_err;
fp8_code = test_code;
subblock_scale = test_scale;
}
}
const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f;
uint32_t q0 = 0;
uint32_t q1 = 0;
#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1
for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) {
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k);
q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4);
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k);
q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4);
}
uint32_t * yqs = reinterpret_cast<uint32_t *>(yb->qs);
yqs[2 * sub + 0] = q0;
yqs[2 * sub + 1] = q1;
reinterpret_cast<uint8_t *>(yb->d4)[sub] = fp8_code;
#else
NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only.
#endif // defined(BLACKWELL_MMA_AVAILABLE)
}
// quantize values in the format mxfp4 is stored which is interleaved nibbles
// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
@@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda(
}
}
void quantize_mmq_mxfp4_cuda(const float * x,
const int32_t * ids,
void * vy,
[[maybe_unused]] const ggml_type type_src0,
const int64_t ne00,
const int64_t s01,
const int64_t s02,
const int64_t s03,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
cudaStream_t stream) {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
void quantize_mmq_fp4_cuda(
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4);
GGML_ASSERT(ne0 > 0);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
if (type_src0 == GGML_TYPE_NVFP4) {
GGML_ASSERT(ne00 % QK_NVFP4 == 0);
constexpr int nvfp4_block_size = 128;
const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size);
const dim3 block_size(nvfp4_block_size, 1, 1);
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
quantize_mmq_nvfp4<<<num_blocks, block_size, 0, stream>>>(
x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
} else {
GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
constexpr int nwarps = 8;
constexpr int vals_per_warp = 2 * QK_MXFP4;
constexpr int vals_per_block = nwarps * vals_per_warp;
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
const dim3 block_size(WARP_SIZE, nwarps, 1);
quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}
}
+1 -1
View File
@@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda(
ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
void quantize_mmq_mxfp4_cuda(const float * x,
void quantize_mmq_fp4_cuda(const float * x,
const int32_t * ids,
void * vy,
ggml_type type_src0,
@@ -2,4 +2,5 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
@@ -2,4 +2,5 @@
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(320, 256);
@@ -3,7 +3,7 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
@@ -62,7 +62,7 @@ for filename in glob("*.cu"):
os.remove(filename)
for head_size_kq in HEAD_SIZES_KQ:
head_size_v = head_size_kq if head_size_kq != 576 else 512
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
@@ -84,13 +84,16 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 72:
continue
if head_size_kq == 512 and ncols2 not in (4, 8):
# Skip compilation of unused ncols2 values for niche head sizes:
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
continue
if head_size_kq != 576 and ncols2 in (16, 32):
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
continue
head_size_v = head_size_kq if head_size_kq != 576 else 512
if head_size_kq not in (320, 576) and ncols2 in (16, 32):
continue
head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512)
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
for type in TYPES_MMQ:
+1 -1
View File
@@ -1101,7 +1101,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
fs::path cache_file = fs::path(cache_dir) / hash_str;
std::ofstream ofs(cache_file, std::ios::binary);
ofs.write((const char *)data, size);
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str());
}
ggml_backend_tensor_set(tensor, data, offset, size);
return true;
+15 -5
View File
@@ -20,12 +20,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
#include <vulkan/vulkan.hpp>
// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and
// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan.
#if defined(_WIN32) && !defined(__MINGW32__)
#include <spirv-headers/spirv.hpp>
// SPIR-V Headers: different SDK installations expose different include paths.
// LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>.
// Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>.
#if __has_include(<spirv/unified1/spirv.hpp>)
# include <spirv/unified1/spirv.hpp>
#elif __has_include(<spirv-headers/spirv.hpp>)
# include <spirv-headers/spirv.hpp>
#elif __has_include(<spirv.hpp>)
# include <spirv.hpp>
#else
#include <spirv/unified1/spirv.hpp>
// Fallback to let the compiler throw a standard "file not found" error
# include <spirv/unified1/spirv.hpp>
#endif
#include <algorithm>
@@ -13007,6 +13014,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
ctx->query_node_idx[ctx->query_idx] = node_idx;
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
ggml_vk_sync_buffers(ctx, compute_ctx);
}
}
// Add all fused nodes to the unsynchronized lists.
@@ -14496,6 +14504,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
compute_ctx = ggml_vk_get_compute_ctx(ctx);
ctx->query_idx = 0;
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
ggml_vk_sync_buffers(ctx, compute_ctx);
}
ctx->prealloc_y_last_pipeline_used = nullptr;
@@ -14732,6 +14741,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
ggml_vk_sync_buffers(ctx, compute_ctx);
} else {
// track a fusion string and number of fused ops for the current node_idx
ctx->query_fusion_names[i] = fusion_string;
@@ -296,13 +296,22 @@ vec2 get_dm_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0],
data_a_packed32[ib_k].scales[1],
data_a_packed32[ib_k].scales[2]);
const uint scalesoffs = (is & 3) * 8;
const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
u8vec2 scale_dm = u8vec2(sc, mbyte);
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm);
}
@@ -201,19 +201,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
data_a_packed32[ib].scales[1],
data_a_packed32[ib].scales[2]);
const uint scalesoffs = (is & 3) * 8;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
@@ -237,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uvec3 scales = uvec3(data_a_packed32[ib].scales[0],
data_a_packed32[ib].scales[1],
data_a_packed32[ib].scales[2]);
const uint scalesoffs = (is & 3) * 8;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const uint scidx0 = (is < 4) ? 0 : 2;
const uint scidxshift0 = scalesoffs;
const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint mbidx0 = (is < 4) ? 1 : 2;
const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4;
const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2;
const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30));
const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
+131 -55
View File
@@ -26,21 +26,21 @@
// Matrix multiplication parameters
// Register tiling parameters
#define WEBGPU_MUL_MAT_TILE_M 4
#define WEBGPU_MUL_MAT_TILE_N 4
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
#define WEBGPU_MUL_MAT_TILE_M 4
#define WEBGPU_MUL_MAT_TILE_N 4
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
// Subgroup matrix parameters
// The number of subgroups in the M dimension
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
// The number of subgroups in the N dimension
#define WEBGPU_MUL_MAT_SUBGROUP_N 4
#define WEBGPU_MUL_MAT_SUBGROUP_N 4
// The number of subgroup matrices each subgroup accumulates over
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
@@ -59,19 +59,32 @@ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
// Calculates base address of a tensor ignoring the fake base pointer
inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
return (uintptr_t) base_tensor->data + tensor->view_offs;
}
inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
}
inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
}
struct ggml_webgpu_shader_lib_context {
ggml_tensor * src0;
ggml_tensor * src1;
ggml_tensor * src2;
ggml_tensor * src3;
ggml_tensor * src4;
ggml_tensor * src5;
ggml_tensor * dst;
uint32_t max_wg_size;
size_t wg_mem_limit_bytes = 0;
bool inplace = false;
bool overlap = false;
bool src_overlap = false;
bool supports_subgroups = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
@@ -88,6 +101,14 @@ struct webgpu_pipeline {
struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size = 0;
bool inplace = false;
};
struct ggml_webgpu_binary_shader_decisions {
uint32_t wg_size = 0;
bool inplace = false;
bool overlap = false;
bool src_overlap = false;
};
struct ggml_webgpu_processed_shader {
@@ -102,11 +123,12 @@ struct ggml_webgpu_ssm_conv_shader_decisions {
};
struct ggml_webgpu_ssm_scan_pipeline_key {
int type;
int d_state;
int type;
int d_state;
bool xbc_overlap;
bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
return type == other.type && d_state == other.d_state;
return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
}
};
@@ -115,6 +137,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.d_state);
ggml_webgpu_hash_combine(seed, key.xbc_overlap);
return seed;
}
};
@@ -122,6 +145,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash {
struct ggml_webgpu_ssm_scan_shader_decisions {
uint32_t wg_size;
uint32_t tokens_per_tile;
bool xbc_overlap = false;
};
/** Argsort **/
@@ -242,6 +266,13 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
}
};
struct ggml_webgpu_rms_norm_mul_shader_decisions {
uint32_t wg_size = 0;
bool inplace = false;
bool overlap = false;
bool src_overlap = false;
};
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
@@ -503,11 +534,12 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
};
struct ggml_webgpu_flash_attn_decisions {
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
uint32_t q_tile = 0;
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
bool kv_direct = false;
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
uint32_t q_tile = 0;
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
bool kv_direct = false;
bool kv_overlap = false;
};
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
@@ -552,7 +584,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
key.head_dim_qk = (uint32_t) context.src0->ne[0];
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.kv_direct = kv_direct;
key.kv_overlap = context.src_overlap;
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = has_mask;
key.has_sinks = has_sinks;
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
@@ -1021,7 +1053,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_row_norm_pipeline_key key = {};
key.op = context.dst->op;
key.inplace = context.inplace;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
auto it = row_norm_pipelines.find(key);
if (it != row_norm_pipelines.end()) {
@@ -1051,8 +1083,12 @@ class ggml_webgpu_shader_lib {
const uint32_t row_norm_wg_size = 128u;
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = wg_size;
decisions->inplace = key.inplace;
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
row_norm_pipelines[key].context = decisions;
return row_norm_pipelines[key];
}
@@ -1127,7 +1163,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_pipeline_key key = {};
key.type = context.dst->type;
key.inplace = context.inplace;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
auto it = set_pipelines.find(key);
if (it != set_pipelines.end()) {
@@ -1160,6 +1196,7 @@ class ggml_webgpu_shader_lib {
auto processed = preprocessor.preprocess(wgsl_set, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->inplace = key.inplace;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
set_pipelines[key] = pipeline;
@@ -1287,6 +1324,7 @@ class ggml_webgpu_shader_lib {
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (key.src_type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@@ -1323,7 +1361,9 @@ class ggml_webgpu_shader_lib {
defines.push_back("DST_TYPE=f32");
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
if (key.src_type == GGML_TYPE_Q1_0) {
defines.push_back("BLOCK_SIZE=128u");
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
key.src_type == GGML_TYPE_IQ4_NL) {
defines.push_back("BLOCK_SIZE=32u");
} else if (key.src_type >= GGML_TYPE_Q2_K) {
@@ -1352,7 +1392,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_scale_pipeline_key key = {};
key.inplace = context.inplace;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
auto it = scale_pipelines.find(key);
if (it != scale_pipelines.end()) {
@@ -1372,6 +1412,7 @@ class ggml_webgpu_shader_lib {
auto processed = preprocessor.preprocess(wgsl_scale, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->inplace = key.inplace;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
scale_pipelines[key] = pipeline;
@@ -1465,6 +1506,8 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_ssm_scan_pipeline_key key = {};
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
@@ -1496,12 +1539,17 @@ class ggml_webgpu_shader_lib {
variant += "_wg_reduce";
}
if (key.xbc_overlap) {
defines.push_back("XBC_OVERLAP");
}
variant += "_d" + std::to_string(key.d_state);
auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);
auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
decisions->wg_size = wg_size;
decisions->tokens_per_tile = tokens_per_tile;
decisions->xbc_overlap = key.xbc_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
ssm_scan_pipelines[key] = pipeline;
@@ -1615,6 +1663,24 @@ class ggml_webgpu_shader_lib {
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
switch (context.src0->type) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
defines.push_back(type_upper + "_GRID");
break;
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
defines.push_back(type_upper + "_GRID");
defines.push_back(type_upper + "_TABLES");
break;
default:
break;
}
break;
}
}
@@ -1639,7 +1705,9 @@ class ggml_webgpu_shader_lib {
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
if (key.src0_type >= GGML_TYPE_Q2_K) {
if (key.src0_type == GGML_TYPE_Q1_0) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
@@ -1741,11 +1809,9 @@ class ggml_webgpu_shader_lib {
uint32_t tile_k;
if (key.use_subgroup_matrix) {
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT
: WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
} else {
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
}
// Tiles
@@ -1978,9 +2044,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("SCALAR");
// mul_mat_id is register-tile only.
const uint32_t tile_k = ggml_is_quantized(context.src0->type)
? WEBGPU_MUL_MAT_REG_TILE_K_QUANT
: WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
const uint32_t tile_k =
ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
// Tiles
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
@@ -2016,8 +2081,8 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.op = op;
key.is_unary = is_unary;
key.inplace = context.inplace;
key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
auto it = unary_pipelines.find(key);
if (it != unary_pipelines.end()) {
@@ -2075,6 +2140,7 @@ class ggml_webgpu_shader_lib {
auto processed = preprocessor.preprocess(wgsl_unary, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->inplace = key.inplace;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
unary_pipelines[key] = pipeline;
@@ -2083,9 +2149,9 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
key.inplace = context.inplace;
key.overlap = context.overlap;
key.src_overlap = context.src_overlap;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
auto it = rms_norm_mul_pipelines.find(key);
if (it != rms_norm_mul_pipelines.end()) {
@@ -2109,12 +2175,15 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
rms_norm_mul_pipelines[key] = pipeline;
auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
pipeline_decisions->wg_size = context.max_wg_size;
pipeline_decisions->inplace = key.inplace;
pipeline_decisions->overlap = key.overlap;
pipeline_decisions->src_overlap = key.src_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = pipeline_decisions;
rms_norm_mul_pipelines[key] = pipeline;
return rms_norm_mul_pipelines[key];
}
@@ -2122,9 +2191,9 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_binary_pipeline_key key = {};
key.type = context.dst->type;
key.op = context.dst->op;
key.inplace = context.inplace;
key.overlap = context.overlap;
key.src_overlap = context.src_overlap;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
auto it = binary_pipelines.find(key);
if (it != binary_pipelines.end()) {
@@ -2163,11 +2232,15 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_binary, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
auto processed = preprocessor.preprocess(wgsl_binary, defines);
auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
pipeline_decisions->wg_size = context.max_wg_size;
pipeline_decisions->inplace = key.inplace;
pipeline_decisions->overlap = key.overlap;
pipeline_decisions->src_overlap = key.src_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
pipeline.context = pipeline_decisions;
binary_pipelines[key] = pipeline;
return binary_pipelines[key];
}
@@ -2328,7 +2401,8 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
}
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
pipeline_decisions->kv_overlap = key.kv_overlap;
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
@@ -2520,7 +2594,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_rope_pipeline_key key = {};
key.type = context.dst->type;
key.inplace = context.inplace;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
key.has_ff = (context.src2 != nullptr);
auto it = rope_pipelines.find(key);
@@ -2559,6 +2633,7 @@ class ggml_webgpu_shader_lib {
auto processed = preprocessor.preprocess(wgsl_rope, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->inplace = key.inplace;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
rope_pipelines[key] = pipeline;
@@ -2570,7 +2645,7 @@ class ggml_webgpu_shader_lib {
key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
key.has_mask = (context.src1 != nullptr);
key.has_sink = (context.src2 != nullptr);
key.inplace = context.inplace;
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
auto it = soft_max_pipelines.find(key);
if (it != soft_max_pipelines.end()) {
@@ -2611,6 +2686,7 @@ class ggml_webgpu_shader_lib {
auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->inplace = key.inplace;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
soft_max_pipelines[key] = pipeline;
+182 -170
View File
@@ -108,12 +108,9 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) {
// their locations.
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
// Always returns the base offset of a tensor, regardless of views.
static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
if (tensor->view_src) {
return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
}
return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs;
}
/* Struct definitions */
@@ -375,10 +372,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
buffer = device.CreateBuffer(&buffer_desc);
}
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
return webgpu_tensor_offset(tensor) + tensor->view_offs;
}
static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
return ctx->buffer;
@@ -398,34 +391,31 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
}
// Used to determine if two tensors are the same for in-place operations
static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}
// Used to determine if two tensors share the same buffer and their byte ranges overlap,
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
}
struct binary_overlap_flags {
bool inplace; // src0 == dst
bool overlap; // src1 == dst
bool src_overlap;
struct ggml_webgpu_merged_binding_range {
size_t offset;
size_t size;
};
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range(
webgpu_context & ctx,
std::initializer_list<ggml_tensor *> tensors) {
size_t merged_offset = SIZE_MAX;
size_t merged_end = 0;
return flags;
for (ggml_tensor * tensor : tensors) {
const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor);
const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor);
merged_offset = std::min(merged_offset, bind_offset);
merged_end = std::max(merged_end, bind_end);
}
return { merged_offset, merged_end - merged_offset };
}
static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor * tensor,
const ggml_webgpu_merged_binding_range & merged_range) {
return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type));
}
static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding,
@@ -753,18 +743,16 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = inplace;
webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const bool inplace = decisions->inplace;
const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
@@ -1126,19 +1114,39 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.src4 = src4;
shader_lib_ctx.src5 = src5;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_ssm_scan_shader_decisions *>(pipeline.context.get());
const bool xbc_overlap = decisions->xbc_overlap;
uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type));
uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type));
size_t xbc_bind_offset = 0;
size_t xbc_bind_size = 0;
if (xbc_overlap) {
const ggml_webgpu_merged_binding_range merged_range =
ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 });
xbc_bind_offset = merged_range.offset;
xbc_bind_size = merged_range.size;
offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range);
offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range);
}
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
offset_x,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)),
offset_B,
offset_C,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1174,11 +1182,24 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx,
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
};
if (xbc_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst));
}
const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]);
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
@@ -1389,8 +1410,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q1_0:
use_fast = true;
break;
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
use_fast = is_vec;
break;
default:
break;
}
@@ -1641,23 +1674,38 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = Q;
shader_lib_ctx.src1 = K;
shader_lib_ctx.src2 = V;
shader_lib_ctx.src3 = mask;
shader_lib_ctx.src4 = sinks;
shader_lib_ctx.dst = dst;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
const int has_mask = (mask != nullptr);
const int has_sinks = (sinks != nullptr);
const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type;
const bool kv_overlap = decisions->kv_overlap;
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
size_t kv_bind_offset = 0;
size_t kv_bind_size = 0;
if (kv_overlap) {
const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K);
const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V);
const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K);
const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V);
kv_bind_offset = std::min(k_bind_offset, v_bind_offset);
kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset;
offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type));
offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type));
const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V });
kv_bind_offset = merged_range.offset;
kv_bind_size = merged_range.size;
offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range);
offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range);
}
std::vector<uint32_t> params = {
@@ -1708,26 +1756,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
}
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = Q;
shader_lib_ctx.src1 = K;
shader_lib_ctx.src2 = V;
shader_lib_ctx.src3 = mask;
shader_lib_ctx.src4 = sinks;
shader_lib_ctx.dst = dst;
shader_lib_ctx.src_overlap = kv_overlap;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
@@ -1909,18 +1937,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_unary = dst->op == GGML_OP_UNARY;
bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src;
shader_lib_ctx.src1 = nullptr;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = inplace;
webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const bool inplace = decisions->inplace;
uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -1982,41 +2009,38 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = flags.inplace;
shader_lib_ctx.overlap = flags.overlap;
shader_lib_ctx.src_overlap = flags.src_overlap;
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
uint32_t offset_merged_src0 = 0;
uint32_t offset_merged_src1 = 0;
if (flags.src_overlap) {
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
size_t merged_offset = 0;
size_t merged_size = 0;
if (decisions->src_overlap) {
const ggml_webgpu_merged_binding_range merged_range =
ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
merged_offset = merged_range.offset;
merged_size = merged_range.size;
offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
}
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
offset_src0,
offset_src1,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
offset_merged_src0,
offset_merged_src1,
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
@@ -2036,12 +2060,9 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
std::vector<wgpu::BindGroupEntry> entries;
if (flags.src_overlap) {
size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset,
merged_end - merged_offset));
if (decisions->src_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
} else {
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0),
@@ -2050,7 +2071,7 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1),
src1_webgpu_tensor_align_offset,
ggml_webgpu_tensor_binding_size(ctx, src1)));
if (!flags.inplace && !flags.overlap) {
if (!decisions->inplace && !decisions->overlap) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
}
}
@@ -2156,29 +2177,15 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
}
bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
uint32_t offset_merged_rn_src = 0;
uint32_t offset_merged_mul_src = 0;
size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src);
size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src);
if (src_overlap) {
size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
offset_merged_rn_src =
(uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type));
offset_merged_mul_src =
(uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type));
}
uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type));
uint32_t offset_mul_src =
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type));
size_t merged_offset = 0;
size_t merged_size = 0;
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)),
offset_merged_rn_src,
offset_merged_mul_src,
offset_rn_src,
offset_mul_src,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)),
(uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)),
@@ -2202,16 +2209,32 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
std::vector<wgpu::BindGroupEntry> entries;
if (inplace || overlap) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = rn_src;
shader_lib_ctx.src1 = mul_src;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_rms_norm_mul_shader_decisions *>(pipeline.context.get());
if (decisions->src_overlap) {
const ggml_webgpu_merged_binding_range merged_range =
ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src });
merged_offset = merged_range.offset;
merged_size = merged_range.size;
offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range);
offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range);
params[0] = offset_rn_src;
params[1] = offset_mul_src;
}
if (decisions->inplace || decisions->overlap) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
} else if (src_overlap) {
size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset);
size_t merged_end =
std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src),
mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src));
entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset,
merged_end - merged_offset));
} else if (decisions->src_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
@@ -2219,20 +2242,10 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = inplace;
shader_lib_ctx.overlap = overlap;
shader_lib_ctx.src_overlap = src_overlap;
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst));
}
static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool inplace = ggml_webgpu_tensor_equal(src, dst);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -2249,18 +2262,18 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor
ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader
};
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
if (!inplace) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = inplace;
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
if (!decisions->inplace) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src));
}
@@ -2275,14 +2288,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx,
shader_lib_ctx.src2 = src2;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);
webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_freq_factor = (src2 != nullptr);
const bool inplace = decisions->inplace;
const int has_freq_factor = (src2 != nullptr);
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
@@ -2409,14 +2421,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx,
}
static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool inplace = ggml_webgpu_tensor_equal(src, dst);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src;
shader_lib_ctx.src1 = nullptr;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = inplace;
webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -2442,7 +2451,7 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
// bindgroups unchanged
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) };
if (!inplace) {
if (!decisions->inplace) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
@@ -2461,17 +2470,17 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
shader_lib_ctx.src2 = src2;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst);
webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_mask = (src1 != nullptr);
const int has_sink = (src2 != nullptr);
float max_bias = ggml_get_op_params_f32(dst, 1);
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const bool inplace = decisions->inplace;
const int has_mask = (src1 != nullptr);
const int has_sink = (src2 != nullptr);
float max_bias = ggml_get_op_params_f32(dst, 1);
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -3067,7 +3076,7 @@ static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend,
size_t size) {
GGML_UNUSED(backend);
auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
// Write aligned portion
buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -3149,7 +3158,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
<< ", " << offset << ", " << size << ")");
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
// This is a trick to set all bytes of a u32 to the same 1 byte value.
uint32_t val32 = (uint32_t) value * 0x01010101;
@@ -3168,7 +3177,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
<< ", " << offset << ", " << size << ")");
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
@@ -3200,7 +3209,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
<< ", " << offset << ", " << size << ")");
wgpu::Device device = buf_ctx->global_ctx->device;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset;
size_t final_size = size;
if (size % 4 != 0) {
@@ -3725,6 +3734,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
static bool ggml_webgpu_supported_qtype(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -3819,6 +3829,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -3857,6 +3868,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -7,8 +7,6 @@ struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
offset_merged_src0: u32,
offset_merged_src1: u32,
stride_src0_0: u32,
stride_src0_1: u32,
@@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
let src0_i = params.offset_src0 + src0_index(gid.x);
let src1_i = params.offset_src1 + src1_index(gid.x);
update(params.offset_dst + gid.x, src0_i, src1_i);
}
}
@@ -27,6 +27,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
#endif
#ifdef Q1_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18;
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 4u; j++) {
let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u);
let dst_base128 = dst_base + offset * 128u + j * 32u;
for (var k: u32 = 0; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
for (var bit: u32 = 0; bit < 8u; bit++) {
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
dst[dst_base128 + k * 8u + bit] = w;
}
}
}
}
#endif
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
@@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
#endif // INIT_SRC1_SHMEM_FLOAT
#endif
#ifdef INIT_SRC0_SHMEM_Q1_0
const BLOCK_SIZE = 128u;
const BLOCK_SIZE_BYTES = 18u;
const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let tile_m = i / TILE_K;
let tile_k_start = i % TILE_K;
let global_m = offset_m + tile_m;
let global_k_start = k_outer + tile_k_start;
if (global_m >= params.m) {
break;
}
let block_k = global_k_start / BLOCK_SIZE;
let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu;
for (var bit = 0u; bit < NQ; bit++) {
let global_k = global_k_start + bit;
if (global_k < params.k) {
shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q1_0
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
@@ -128,6 +128,38 @@ fn main(
}
#endif
#ifdef MUL_ACC_Q1_0
#define BLOCK_SIZE 128
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 16
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
let num_blocks = params.k / BLOCK_SIZE;
let thread_within_block = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
var x_block: array<f32, ELEMS_PER_THREAD>;
for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu;
var row_sum = 0.0;
for (var bit = 0u; bit < 8u; bit++) {
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
row_sum += w * x_block[bit];
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_Q4_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
@@ -812,6 +844,520 @@ fn main(
}
#endif
#ifdef MUL_ACC_IQ1_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 50
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu;
let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_byte = get_byte(qs_w, l);
let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
let gw = iq1_grid[ig / 16u];
let bit_base = (ig % 16u) * 2u;
for (var j = 0u; j < 8u; j++) {
let g = (gw >> (bit_base + j * 2u)) & 3u;
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ1_M
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 56
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let sc_lo = load_u32_at_src0(block_byte_base + 48u);
let sc_hi = load_u32_at_src0(block_byte_base + 52u);
let sc0 = sc_lo & 0xFFFFu;
let sc1 = (sc_lo >> 16u) & 0xFFFFu;
let sc2 = sc_hi & 0xFFFFu;
let sc3 = (sc_hi >> 16u) & 0xFFFFu;
let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u);
let d = f32(bitcast<vec2<f16>>(d_bits)[0]);
let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u),
select(sc0, sc1, sub_blk >= 2u),
sub_blk < 4u);
let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u);
let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu;
let qh_lo = qh & 0xFFu;
let qh_hi = (qh >> 8u) & 0xFFu;
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u);
let sub_scale = (sc_u16 >> bit_off) & 0x7u;
let dl = d * f32(2u * sub_scale + 1u);
let qh_byte = select(qh_lo, qh_hi, l >= 2u);
let ll2 = l % 2u;
let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u);
let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u);
let ig = grid_idx * 8u;
let gw = iq1_grid[ig / 16u];
let bit_base = (ig % 16u) * 2u;
for (var j = 0u; j < 8u; j++) {
let g = (gw >> (bit_base + j * 2u)) & 3u;
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ2_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 66
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let ls = aux_hi >> 28u;
let db = d * (0.5 + f32(ls)) * 0.25;
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let grid_idx = (aux_lo >> (8u * l)) & 0xFFu;
let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu;
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
let gw_lo = iq2xxs_grid[grid_idx * 2u];
let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u];
for (var j = 0u; j < 8u; j++) {
let gw = select(gw_hi, gw_lo, j < 4u);
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
row_sum += db * b * s * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ2_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 74
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
let scales_byte = get_byte(scales_word, sub_blk % 4u);
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_word = select(qs_hi, qs_lo, l < 2u);
let half2 = (l % 2u) * 16u;
let qs_val = (qs_word >> half2) & 0xFFFFu;
let grid_idx = qs_val & 0x1FFu;
let signs_idx = (qs_val >> 9u) & 0x7Fu;
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
let db = d * (0.5 + f32(sub_scale)) * 0.25;
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
let gw_lo = iq2xs_grid[grid_idx * 2u];
let gw_hi = iq2xs_grid[grid_idx * 2u + 1u];
for (var j = 0u; j < 8u; j++) {
let gw = select(gw_hi, gw_lo, j < 4u);
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
row_sum += db * b * s * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ2_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 82
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u);
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
let qh_byte = get_byte(qh_word, sub_blk % 4u);
let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u);
let scales_byte = get_byte(sc_word, sub_blk % 4u);
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_byte = get_byte(qs_w, l);
let sign_byte = get_byte(sg_w, l);
let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u);
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
let db = d * (0.5 + f32(sub_scale)) * 0.25;
let gw_lo = iq2s_grid[grid_idx * 2u];
let gw_hi = iq2s_grid[grid_idx * 2u + 1u];
for (var j = 0u; j < 8u; j++) {
let gw = select(gw_hi, gw_lo, j < 4u);
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
row_sum += db * b * s * x_block[ll * 8u + j];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ3_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 98
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u);
let ls = aux >> 28u;
let db = d * (0.5 + f32(ls)) * 0.5;
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_word = select(qs_hi, qs_lo, l < 2u);
let byte_pos = (l % 2u) * 2u;
let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
let signs_idx = (aux >> (7u * l)) & 0x7Fu;
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
let grid1 = iq3xxs_grid[grid_idx_0];
let grid2 = iq3xxs_grid[grid_idx_1];
for (var j = 0u; j < 4u; j++) {
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u);
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ3_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 110
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let slot0 = half * 2u;
let y_offset = sub_blk * 32u + slot0 * 8u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
let qh_byte = get_byte(qh_word, sub_blk % 4u);
let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u);
let sc_word = load_u32_at_src0(block_byte_base + 106u);
let scales_byte = get_byte(sc_word, sub_blk / 2u);
let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu;
let db = d * (1.0 + 2.0 * f32(sub_scale));
var row_sum = 0.0;
for (var ll = 0u; ll < 2u; ll++) {
let l = slot0 + ll;
let qs_word = select(qs_hi, qs_lo, l < 2u);
let byte_pos = (l % 2u) * 2u;
let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u);
let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u);
let sign_byte = get_byte(sg_w, l);
let grid1 = iq3s_grid[grid_idx_1];
let grid2 = iq3s_grid[grid_idx_2];
for (var j = 0u; j < 4u; j++) {
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u);
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
}
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ4_NL
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
let num_blocks = params.k / BLOCK_SIZE;
let thread_within_block = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u;
var x_block: array<f32, ELEMS_PER_THREAD>;
for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) {
x_block[i] = f32(src1[x_base + i]);
x_block[i + 4u] = f32(src1[x_base + i + 16u]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
var row_sum = 0.0;
let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
let q_byte = get_byte(q_packed, byte_idx);
let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d;
let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d;
row_sum += q_lo * x_block[byte_idx];
row_sum += q_hi * x_block[byte_idx + 4u];
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef MUL_ACC_IQ4_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 136
#define THREADS_PER_BLOCK 16
let tid = thread_id % THREADS_PER_BLOCK;
let block_group = thread_id / THREADS_PER_BLOCK;
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
let sub_blk = tid / 2u;
let half = tid % 2u;
let y_offset = sub_blk * 32u + half * 16u;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = block_group; block < num_blocks; block += num_block_groups) {
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
var x_block: array<f32, 16>;
for (var i = 0u; i < 16u; i++) {
x_block[i] = f32(src1[x_base + i]);
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let scales_h = load_u16_at_src0(block_byte_base + 2u);
let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
let sl_byte = get_byte(scales_l_word, sub_blk / 2u);
let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu;
let sh_bits = (scales_h >> (2u * sub_blk)) & 3u;
let ls = i32(sl | (sh_bits << 4u));
let dl = d * f32(ls - 32);
let qs_byte_off = 8u + sub_blk * 16u;
let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off);
let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u);
let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u);
let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u);
var row_sum = 0.0;
for (var i = 0u; i < 16u; i++) {
let q_word = select(
select(q_w0, q_w1, i >= 4u),
select(q_w2, q_w3, i >= 12u),
i >= 8u);
let q_byte = get_byte(q_word, i % 4u);
let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u);
row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i];
}
acc[row] += row_sum;
}
}
}
#endif
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
@@ -66,8 +66,6 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
struct Params {
offset_rn_src: u32,
offset_mul_src: u32,
offset_merged_rn_src: u32,
offset_merged_mul_src: u32,
offset_dst: u32,
stride_rn_src1: u32,
@@ -107,8 +105,8 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1;
let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
@@ -45,6 +45,14 @@ struct Params {
};
@group(0) @binding(0) var<storage, read_write> s_in: array<f32>;
#ifdef XBC_OVERLAP
@group(0) @binding(1) var<storage, read_write> x_B_C_merged: array<f32>;
@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
@group(0) @binding(3) var<storage, read_write> A: array<f32>;
@group(0) @binding(4) var<storage, read_write> ids: array<i32>;
@group(0) @binding(5) var<storage, read_write> dst: array<f32>;
@group(0) @binding(6) var<uniform> params: Params;
#else
@group(0) @binding(1) var<storage, read_write> x: array<f32>;
@group(0) @binding(2) var<storage, read_write> dt: array<f32>;
@group(0) @binding(3) var<storage, read_write> A: array<f32>;
@@ -53,6 +61,7 @@ struct Params {
@group(0) @binding(6) var<storage, read_write> ids: array<i32>;
@group(0) @binding(7) var<storage, read_write> dst: array<f32>;
@group(0) @binding(8) var<uniform> params: Params;
#endif
var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>;
var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>;
@@ -98,7 +107,11 @@ fn main(
let dt0 = dt[dt_idx];
let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0);
shared_dtsp[tid] = dtsp;
#ifdef XBC_OVERLAP
shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp;
#else
shared_x_dt[tid] = x[x_idx] * dtsp;
#endif
}
}
@@ -116,16 +129,28 @@ fn main(
let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3;
let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3;
#ifdef XBC_OVERLAP
let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt;
#else
let s = s_prev * dA + B[b_idx] * x_dt;
#endif
s_prev = s;
#ifdef USE_SUBGROUP_REDUCTION
#ifdef XBC_OVERLAP
let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]);
#else
let subgroup_partial = subgroupAdd(s * C[c_idx]);
#endif
if (subgroup_invocation_id == 0u) {
shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial;
}
#else
#ifdef XBC_OVERLAP
shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx];
#else
shared_reduce[reduce_idx] = s * C[c_idx];
#endif
#endif
workgroupBarrier();
-3
View File
@@ -72,9 +72,6 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
if (model.layers[il].wo_s) {
cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
}
cb(cur, "attn_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
-3
View File
@@ -58,9 +58,6 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
if (model.layers[il].wo_s) {
cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
}
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-3
View File
@@ -58,9 +58,6 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
if (model.layers[il].wo_s) {
cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
}
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+7 -3
View File
@@ -40,8 +40,12 @@ int main(void) {
}
}
// exclude spec args from this check
// ref: https://github.com/ggml-org/llama.cpp/pull/22397
const bool skip = opt.is_spec;
// ensure shorter argument precedes longer argument
if (opt.args.size() > 1) {
if (!skip && opt.args.size() > 1) {
const std::string first(opt.args.front());
const std::string last(opt.args.back());
@@ -124,9 +128,9 @@ int main(void) {
assert(params.n_batch == 9090);
// --draft cannot be used outside llama-speculative
argv = {"binary_name", "--draft", "123"};
argv = {"binary_name", "--spec-draft-n-max", "123"};
assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
assert(params.speculative.n_max == 123);
assert(params.speculative.draft.n_max == 123);
// multi-value args (CSV)
argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"};
+26 -9
View File
@@ -2984,7 +2984,7 @@ struct test_bin_bcast : public test_case {
bool run_whole_graph() override { return nf > 1; }
std::string vars() override {
return VARS_TO_STR5(type, ne, nr, nf, perm1);
return VARS_TO_STR6(type, ne, nr, nf, perm1, src_overlap);
}
size_t op_size(ggml_tensor * t) override {
@@ -3589,9 +3589,10 @@ struct test_ssm_scan : public test_case {
const int64_t n_group;
const int64_t n_seq_tokens;
const int64_t n_seqs;
const bool xbc_overlap;
std::string vars() override {
return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
return VARS_TO_STR8(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs, xbc_overlap);
}
test_ssm_scan(ggml_type type = GGML_TYPE_F32,
@@ -3600,16 +3601,31 @@ struct test_ssm_scan : public test_case {
int64_t n_head = 32,
int64_t n_group = 1,
int64_t n_seq_tokens = 32,
int64_t n_seqs = 32)
: type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
int64_t n_seqs = 32,
bool xbc_overlap = false)
: type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), xbc_overlap(xbc_overlap) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
ggml_tensor * x;
ggml_tensor * B;
ggml_tensor * C;
if (xbc_overlap) {
ggml_tensor * xbc = ggml_new_tensor_4d(ctx, type, d_state, n_head, n_seq_tokens, 2 * n_seqs);
x = ggml_view_4d(ctx, xbc, head_dim, n_head, n_seq_tokens, n_seqs,
xbc->nb[1], xbc->nb[2], xbc->nb[3], xbc->nb[3]);
B = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
xbc->nb[1], xbc->nb[2], xbc->nb[3], 0);
C = ggml_view_4d(ctx, xbc, d_state, n_group, n_seq_tokens, n_seqs,
xbc->nb[1], xbc->nb[2], xbc->nb[3], 2 * xbc->nb[3]);
} else {
x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
}
ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
return out;
@@ -3799,7 +3815,7 @@ struct test_mul_mat : public test_case {
double max_nmse_err(ggml_backend_t backend) override {
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
return 2e-2;
}
return max_nmse_err();
@@ -3935,7 +3951,7 @@ struct test_mul_mat_id : public test_case {
double max_nmse_err(ggml_backend_t backend) override {
// for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
if ((type_a == GGML_TYPE_MXFP4 || type_a == GGML_TYPE_NVFP4) && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
return 2e-2;
}
return max_nmse_err();
@@ -7964,6 +7980,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 128, 4, 4, 16, 2, true)); // x/B/C overlap
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
+40
View File
@@ -2249,6 +2249,46 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.expect(message_assist)
.run();
{
// additional tests for https://github.com/ggml-org/llama.cpp/pull/21760
auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja");
common_chat_msg tool_call_msg = simple_assist_msg(
"Let me check.", "", "special_function", "{\"arg1\": 1}","c0");
common_chat_msg tool_msg;
tool_msg.role = "tool";
tool_msg.tool_name = "special_function";
tool_msg.tool_call_id = "c0";
tool_msg.content = "{\"r\":\"ok\"}";
{
common_chat_templates_inputs inputs;
inputs.messages = { message_user, tool_call_msg, tool_msg };
inputs.tools = { special_function_tool };
inputs.add_generation_prompt = true;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
if (!string_ends_with(params.prompt, "<turn|>\n<|turn>model\n")) {
throw std::runtime_error("Missing generation prompt for Gemma 4");
}
}
{
common_chat_templates_inputs inputs;
inputs.messages = { message_user, tool_call_msg, tool_msg };
inputs.tools = { special_function_tool };
inputs.add_generation_prompt = false;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
if (string_ends_with(params.prompt, "<|turn>model\n")) {
throw std::runtime_error("Gemma 4: generation prompt was modified despite add_generation_prompt=false");
}
}
}
}
{
+4
View File
@@ -35,5 +35,9 @@ int main() {
threads[i].join();
}
common_log_flush(common_log_main());
// We explicitly free the logger singleton to avoid hanging on Windows
// related to timing issues of thread startup and DLL teardown
common_log_free(common_log_main());
return 0;
}
+24 -1
View File
@@ -227,7 +227,30 @@ int main(void) {
3); // forcing continues through i=3
}
printf("OK (5 tests passed)\n");
// Test 6: Multi-block thinking. First block ends naturally at i=2, second
// start tag at i=3 re-arms the budget, which then exhausts at i=5.
// Regression: before this fix, DONE absorbed all subsequent tokens and a
// second <think> block ran unbudgeted.
// Flow: i=0 accept(100)->COUNTING rem=2; i=1 accept(50)->rem=1;
// i=2 accept(101)->end_matcher matches, DONE;
// i=3 accept(100)->re-arm, COUNTING rem=2;
// i=4 accept(60)->rem=1; i=5 accept(61)->rem=0->FORCING;
// i=6 apply()->forces token[0]=102, accept(62)->force_pos=1, stay FORCING;
// i=7 apply()->forces token[1]=101, accept(63)->force_pos=2->DONE.
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};
const std::vector<llama_token> sequence = {100, 50, 101, 100, 60, 61, 62, 63};
test_reasoning_budget("multi-block re-arms budget after DONE", sequence, start, end, forced,
2, // budget of 2 tokens (per block)
REASONING_BUDGET_IDLE,
6, // forcing starts at i=6 (after second block exhausts at i=5)
7); // forcing continues through i=7
}
printf("OK (6 tests passed)\n");
printf("Testing UTF-8 boundary detection... ");
test_utf8_boundary_detection();
+1 -1
View File
@@ -372,7 +372,7 @@ static const cmd_params cmd_params_defaults = {
/* n_ubatch */ { 512 },
/* type_k */ { GGML_TYPE_F16 },
/* type_v */ { GGML_TYPE_F16 },
/* n_threads */ { cpu_get_num_math() },
/* n_threads */ { common_cpu_get_num_math() },
/* cpu_mask */ { "0x0" },
/* cpu_strict */ { false },
/* poll */ { 50 },
+1 -1
View File
@@ -317,7 +317,7 @@ int main(int argc, char * argv[]) {
const char * cache_dir = nullptr;
std::string cache_dir_str;
if (params.use_cache) {
cache_dir_str = fs_get_cache_directory() + "rpc/";
cache_dir_str = fs_get_cache_directory() + "rpc" + DIRECTORY_SEPARATOR;
if (!fs_create_directory_with_parents(cache_dir_str)) {
fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
return 1;
File diff suppressed because one or more lines are too long
+5618 -5478
View File
File diff suppressed because it is too large Load Diff
File diff suppressed because one or more lines are too long
+2 -2
View File
@@ -575,14 +575,14 @@ json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
json convert_transcriptions_to_chatcmpl(
const json & inp_body,
const common_chat_templates * tmpls,
const std::map<std::string, raw_buffer> & in_files,
const std::map<std::string, uploaded_file> & in_files,
std::vector<raw_buffer> & out_files) {
// TODO @ngxson : this function may need to be improved in the future
// handle input files
out_files.clear();
auto it = in_files.find("file");
if (it != in_files.end()) {
out_files.push_back(it->second);
out_files.push_back(it->second.data);
} else {
throw std::invalid_argument("No input file found for transcription");
}
+2 -1
View File
@@ -4,6 +4,7 @@
#include "chat.h"
#include "server-common.h"
#include "server-http.h"
#include <nlohmann/json_fwd.hpp>
@@ -19,7 +20,7 @@ json server_chat_convert_anthropic_to_oai(const json & body);
json convert_transcriptions_to_chatcmpl(
const json & body,
const common_chat_templates * tmpls,
const std::map<std::string, raw_buffer> & in_files,
const std::map<std::string, uploaded_file> & in_files,
std::vector<raw_buffer> & out_files);
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
+12 -15
View File
@@ -309,8 +309,10 @@ struct server_slot {
return 0;
}
const int n_draft_min = common_speculative_n_min(spec.get(), task->params.speculative);
// determine the max draft that fits the current slot state
int n_draft_max = task->params.speculative.n_max;
int n_draft_max = common_speculative_n_max(spec.get(), task->params.speculative);
// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
@@ -322,8 +324,8 @@ struct server_slot {
SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
if (n_draft_max < task->params.speculative.n_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
if (n_draft_max < n_draft_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, n_draft_min);
n_draft_max = 0;
}
@@ -358,11 +360,6 @@ struct server_slot {
spec_draft.resize(n_draft_max);
}
if (spec_draft.size() < (size_t) params_spec.n_min) {
SLT_DBG(*this, "ignoring small draft: %d < %d\n", (int) spec_draft.size(), params_spec.n_min);
spec_draft.clear();
}
if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
const auto n_tokens = prompt.tokens.size();
@@ -770,9 +767,9 @@ private:
if (params_base.speculative.has_dft()) {
// TODO speculative: move to common/speculative.cpp?
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative.draft;
const auto & params_spec = params_base.speculative;
SRV_INF("loading draft model '%s'\n", params_spec.mparams.path.c_str());
auto params_dft = params_base;
@@ -780,7 +777,7 @@ private:
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams_dft;
params_dft.model = params_spec.mparams;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
params_dft.cache_type_k = params_spec.cache_type_k;
params_dft.cache_type_v = params_spec.cache_type_v;
@@ -800,8 +797,8 @@ private:
return false;
}
params_base.speculative.model_dft = model_dft.get();
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
params_base.speculative.draft.model = model_dft.get();
params_base.speculative.draft.cparams = common_context_params_to_llama(params_dft);
}
std::string & mmproj_path = params_base.mmproj.path;
@@ -1310,7 +1307,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.can_speculate() && task.params.speculative.n_max > 0);
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
@@ -3031,7 +3028,7 @@ private:
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.n_tokens(), -1);
llama_memory_seq_rm(llama_get_memory(slot.ctx), slot.id, slot.prompt.tokens.pos_next(), -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
+1
View File
@@ -49,6 +49,7 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
parsed_url.path,
headers,
req.body,
req.files,
req.should_stop,
600, // timeout_read (default to 10 minutes)
600 // timeout_write (default to 10 minutes)
+6 -2
View File
@@ -438,7 +438,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
std::string body = req.body;
std::map<std::string, raw_buffer> files;
std::map<std::string, uploaded_file> files;
if (req.is_multipart_form_data()) {
// translate text fields to a JSON object and use it as the body
@@ -459,7 +459,11 @@ void server_http_context::post(const std::string & path, const server_http_conte
// populate files from multipart form
for (const auto & [key, file] : req.form.files) {
files[key] = raw_buffer(file.content.begin(), file.content.end());
files[key] = uploaded_file{
raw_buffer(file.content.begin(), file.content.end()),
file.filename,
file.content_type,
};
}
}
+7 -1
View File
@@ -36,13 +36,19 @@ struct server_http_res {
using server_http_res_ptr = std::unique_ptr<server_http_res>;
using raw_buffer = std::vector<uint8_t>;
struct uploaded_file {
raw_buffer data;
std::string filename;
std::string content_type;
};
struct server_http_req {
std::map<std::string, std::string> params; // path_params + query_params
std::map<std::string, std::string> headers; // used by MCP proxy
std::string path;
std::string query_string; // query parameters string (e.g. "action=save")
std::string body;
std::map<std::string, raw_buffer> files; // used for file uploads (form data)
std::map<std::string, uploaded_file> files; // used for file uploads (form data)
const std::function<bool()> & should_stop;
std::string get_param(const std::string & key, const std::string & def = "") const {
+116 -4
View File
@@ -18,6 +18,8 @@
#include <chrono>
#include <queue>
#include <filesystem>
#include <random>
#include <sstream>
#include <cstring>
#ifdef _WIN32
@@ -823,6 +825,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
proxy_path,
req.headers,
req.body,
req.files,
req.should_stop,
base_params.timeout_read,
base_params.timeout_write
@@ -1126,6 +1129,77 @@ static bool should_strip_proxy_header(const std::string & header_name) {
return false;
}
static std::string generate_multipart_boundary() {
thread_local std::mt19937 gen(std::random_device{}());
static const char chars[] = "0123456789abcdefghijklmnopqrstuvwxyz";
std::uniform_int_distribution<> dis(0, sizeof(chars) - 2);
std::string boundary = "----llama-cpp-proxy-";
for (int i = 0; i < 16; i++) {
boundary += chars[dis(gen)];
}
return boundary;
}
static std::string build_multipart_body(
const json & form_fields,
const std::map<std::string, uploaded_file> & files,
const std::string & boundary) {
static auto sanitize_field = [](const std::string & text) {
std::string result;
result.reserve(text.size());
for (char c : text) {
if (c != '\n' && c != '\r' && c != '"') {
result += c;
}
}
return result;
};
std::ostringstream body;
for (const auto & [key, value] : form_fields.items()) {
if (value.is_array()) {
for (const auto & item : value) {
body << "--" << boundary << "\r\n";
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
body << "\r\n";
if (!item.is_string()) {
throw std::invalid_argument("expected string");
}
body << item.get<std::string>() << "\r\n";
}
} else {
body << "--" << boundary << "\r\n";
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
body << "\r\n";
if (!value.is_string()) {
throw std::invalid_argument("expected string");
}
body << value.get<std::string>() << "\r\n";
}
}
for (const auto & [key, file] : files) {
body << "--" << boundary << "\r\n";
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"";
if (!file.filename.empty()) {
body << "; filename=\"" << sanitize_field(file.filename) << "\"";
}
body << "\r\n";
if (!file.content_type.empty()) {
body << "Content-Type: " << sanitize_field(file.content_type) << "\r\n";
} else {
body << "Content-Type: application/octet-stream\r\n";
}
body << "\r\n";
body.write(reinterpret_cast<const char*>(file.data.data()), file.data.size());
body << "\r\n";
}
body << "--" << boundary << "--\r\n";
return body.str();
}
server_http_proxy::server_http_proxy(
const std::string & method,
const std::string & scheme,
@@ -1134,6 +1208,7 @@ server_http_proxy::server_http_proxy(
const std::string & path,
const std::map<std::string, std::string> & headers,
const std::string & body,
const std::map<std::string, uploaded_file> & files,
const std::function<bool()> should_stop,
int32_t timeout_read,
int32_t timeout_write
@@ -1195,28 +1270,65 @@ server_http_proxy::server_http_proxy(
return pipe->write({{}, 0, std::string(data, data_length), ""});
};
// when files are present, the body was converted from multipart form data to JSON
// we need to reconstruct the multipart body for the downstream server
std::string effective_body = body;
std::string override_content_type;
bool has_files = !files.empty();
if (has_files) {
json form_fields = json::parse(body, nullptr, false);
if (!form_fields.is_discarded()) {
auto boundary = generate_multipart_boundary();
effective_body = build_multipart_body(form_fields, files, boundary);
override_content_type = "multipart/form-data; boundary=" + boundary;
} else {
throw std::runtime_error("failed to parse multipart form fields JSON");
}
}
// prepare the request to destination server
httplib::Request req;
{
req.method = method;
req.path = path;
for (const auto & [key, value] : headers) {
if (key == "Accept-Encoding") {
const auto lowered = to_lower_copy(key);
if (lowered == "accept-encoding") {
// disable Accept-Encoding to avoid compressed responses
continue;
}
if (key == "Transfer-Encoding") {
if (lowered == "transfer-encoding") {
// the body is already decoded
continue;
}
if (key == "Host" || key == "host") {
if (lowered == "content-length") {
// let httplib calculate Content-Length from the actual body
continue;
}
if (lowered == "content-type") {
if (has_files) {
// we set our own Content-Type with the new boundary
continue;
}
// when no files but the original request was multipart,
// the body is now JSON, so correct the Content-Type
if (value.find("multipart/form-data") != std::string::npos) {
override_content_type = "application/json; charset=utf-8";
continue;
}
}
if (lowered == "host") {
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
} else {
req.set_header(key, value);
}
}
req.body = body;
req.body = effective_body;
if (!override_content_type.empty()) {
req.set_header("Content-Type", override_content_type);
}
req.response_handler = response_handler;
req.content_receiver = content_receiver;
}
+1
View File
@@ -202,6 +202,7 @@ public:
const std::string & path,
const std::map<std::string, std::string> & headers,
const std::string & body,
const std::map<std::string, uploaded_file> & files,
const std::function<bool()> should_stop,
int32_t timeout_read,
int32_t timeout_write
+10 -18
View File
@@ -76,13 +76,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -139,13 +133,7 @@ json task_params::to_json(bool only_metrics) const {
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"generation_prompt", chat_parser_params.generation_prompt},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -308,14 +296,17 @@ task_params server_task::params_from_json_cmpl(
params.speculative = defaults.speculative;
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
// TODO: for now, be able to adjust only the draft-model based speculative parameters
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
params.speculative.draft.p_min = json_value(data, "speculative.p_min", defaults.speculative.draft.p_min);
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
params.speculative.n_min = std::max(params.speculative.n_min, 0);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
params.speculative.draft.n_min = std::min(params.speculative.draft.n_max, params.speculative.draft.n_min);
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
#if 0
// for debugging and research purposes
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
@@ -325,6 +316,7 @@ task_params server_task::params_from_json_cmpl(
params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024);
params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024);
params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024);
#endif
// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
+6 -9
View File
@@ -83,15 +83,14 @@ class ServerProcess:
kv_unified: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
models_dir: str | None = None
models_max: int | None = None
no_models_autoload: bool | None = None
lora_files: List[str] | None = None
enable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None
spec_draft_n_min: int | None = None
spec_draft_n_max: int | None = None
no_webui: bool | None = None
jinja: bool | None = None
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
@@ -165,8 +164,6 @@ class ServerProcess:
server_args.extend(["--threads", self.n_threads])
if self.n_gpu_layer:
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
if self.draft is not None:
server_args.extend(["--draft", self.draft])
if self.server_continuous_batching:
server_args.append("--cont-batching")
if self.server_embeddings:
@@ -214,10 +211,10 @@ class ServerProcess:
server_args.append("--context-shift")
if self.api_key:
server_args.extend(["--api-key", self.api_key])
if self.draft_max:
server_args.extend(["--draft-max", self.draft_max])
if self.draft_min:
server_args.extend(["--draft-min", self.draft_min])
if self.spec_draft_n_max:
server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max])
if self.spec_draft_n_min:
server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min])
if self.no_webui:
server_args.append("--no-webui")
if self.no_models_autoload:
+392 -500
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -29,7 +29,7 @@
--chart-3: oklch(0.398 0.07 227.392);
--chart-4: oklch(0.828 0.189 84.429);
--chart-5: oklch(0.769 0.188 70.08);
--sidebar: oklch(0.987 0 0);
--sidebar: oklch(0.985 0 0);
--sidebar-foreground: oklch(0.145 0 0);
--sidebar-primary: oklch(0.205 0 0);
--sidebar-primary-foreground: oklch(0.985 0 0);
@@ -77,7 +77,7 @@
--chart-3: oklch(0.769 0.188 70.08);
--chart-4: oklch(0.627 0.265 303.9);
--chart-5: oklch(0.645 0.246 16.439);
--sidebar: oklch(0.19 0 0);
--sidebar: oklch(0.2 0 0);
--sidebar-foreground: oklch(0.985 0 0);
--sidebar-primary: oklch(0.488 0.243 264.376);
--sidebar-primary-foreground: oklch(0.985 0 0);
@@ -1,18 +1,20 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import { Button, type ButtonVariant, type ButtonSize } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import type { Component } from 'svelte';
import { TooltipSide } from '$lib/enums';
interface Props {
icon: Component;
tooltip: string;
variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link';
size?: 'default' | 'sm' | 'lg' | 'icon';
variant?: ButtonVariant;
size?: ButtonSize;
iconSize?: string;
class?: string;
disabled?: boolean;
onclick: (e?: MouseEvent) => void;
'aria-label'?: string;
tooltipSide?: TooltipSide;
}
let {
@@ -23,6 +25,7 @@
class: className = '',
disabled = false,
iconSize = 'h-3 w-3',
tooltipSide = TooltipSide.TOP,
onclick,
'aria-label': ariaLabel
}: Props = $props();
@@ -35,7 +38,7 @@
{size}
{disabled}
{onclick}
class="h-6 w-6 p-0 {className} flex"
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
aria-label={ariaLabel || tooltip}
>
{@const IconComponent = icon}
@@ -44,7 +47,7 @@
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<Tooltip.Content side={tooltipSide}>
<p>{tooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
@@ -300,7 +300,7 @@
if (sendOnEnter || isModifier) {
event.preventDefault();
if (!canSubmit || disabled || isLoading || hasLoadingAttachments) return;
if (!canSubmit || disabled || hasLoadingAttachments) return;
onSubmit?.();
}
@@ -555,7 +555,7 @@
class="relative {className}"
onsubmit={(e) => {
e.preventDefault();
if (!canSubmit || disabled || isLoading || hasLoadingAttachments) return;
if (!canSubmit || disabled || hasLoadingAttachments) return;
onSubmit?.();
}}
>
@@ -1,333 +0,0 @@
<script lang="ts">
import { page } from '$app/state';
import { Plus, MessageSquare, Settings, Zap, FolderOpen } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { Switch } from '$lib/components/ui/switch';
import { FILE_TYPE_ICONS, TOOLTIP_DELAY_DURATION } from '$lib/constants';
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerSettingsEntry } from '$lib/types';
interface Props {
class?: string;
disabled?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
hasMcpPromptsSupport?: boolean;
hasMcpResourcesSupport?: boolean;
onFileUpload?: () => void;
onSystemPromptClick?: () => void;
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
}
let {
class: className = '',
disabled = false,
hasAudioModality = false,
hasVisionModality = false,
hasMcpPromptsSupport = false,
hasMcpResourcesSupport = false,
onFileUpload,
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
}: Props = $props();
let isNewChat = $derived(!page.params.id);
let systemMessageTooltip = $derived(
isNewChat
? 'Add custom system message for a new conversation'
: 'Inject custom system message at the beginning of the conversation'
);
let dropdownOpen = $state(false);
let mcpServers = $derived(mcpStore.getServersSorted().filter((s) => s.enabled));
let hasMcpServers = $derived(mcpServers.length > 0);
let mcpSearchQuery = $state('');
let filteredMcpServers = $derived.by(() => {
const query = mcpSearchQuery.toLowerCase().trim();
if (!query) return mcpServers;
return mcpServers.filter((s) => {
const name = getServerLabel(s).toLowerCase();
const url = s.url.toLowerCase();
return name.includes(query) || url.includes(query);
});
});
function getServerLabel(server: MCPServerSettingsEntry): string {
return mcpStore.getServerLabel(server);
}
function isServerEnabledForChat(serverId: string): boolean {
return conversationsStore.isMcpServerEnabledForChat(serverId);
}
async function toggleServerForChat(serverId: string) {
await conversationsStore.toggleMcpServerForChat(serverId);
}
function handleMcpSubMenuOpen(open: boolean) {
if (open) {
mcpSearchQuery = '';
mcpStore.runHealthChecksForServers(mcpServers);
}
}
function handleMcpPromptClick() {
dropdownOpen = false;
onMcpPromptClick?.();
}
function handleMcpSettingsClick() {
dropdownOpen = false;
onMcpSettingsClick?.();
}
function handleMcpResourcesClick() {
dropdownOpen = false;
onMcpResourcesClick?.();
}
const fileUploadTooltipText = 'Add files, system prompt or MCP Servers';
</script>
<div class="flex items-center gap-1 {className}">
<DropdownMenu.Root bind:open={dropdownOpen}>
<DropdownMenu.Trigger name="Attach files" {disabled}>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
>
<span class="sr-only">{fileUploadTooltipText}</span>
<Plus class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{fileUploadTooltipText}</p>
</Tooltip.Content>
</Tooltip.Root>
</DropdownMenu.Trigger>
<DropdownMenu.Content align="start" class="w-48">
{#if hasVisionModality}
<DropdownMenu.Item
class="images-button flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.image class="h-4 w-4" />
<span>Images</span>
</DropdownMenu.Item>
{:else}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="images-button flex cursor-pointer items-center gap-2"
disabled
>
<FILE_TYPE_ICONS.image class="h-4 w-4" />
<span>Images</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>Image processing requires a vision model</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{#if hasAudioModality}
<DropdownMenu.Item
class="audio-button flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
<span>Audio Files</span>
</DropdownMenu.Item>
{:else}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item class="audio-button flex cursor-pointer items-center gap-2" disabled>
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
<span>Audio Files</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>Audio files processing requires an audio model</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.text class="h-4 w-4" />
<span>Text Files</span>
</DropdownMenu.Item>
{#if hasVisionModality}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
<span>PDF Files</span>
</DropdownMenu.Item>
{:else}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onFileUpload?.()}
>
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
<span>PDF Files</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => onSystemPromptClick?.()}
>
<MessageSquare class="h-4 w-4" />
<span>System Message</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{systemMessageTooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
<DropdownMenu.Separator />
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
<McpLogo class="h-4 w-4" />
<span>MCP Servers</span>
</DropdownMenu.SubTrigger>
<DropdownMenu.SubContent class="w-72 pt-0">
<DropdownMenuSearchable
placeholder="Search servers..."
bind:searchValue={mcpSearchQuery}
emptyMessage={hasMcpServers ? 'No servers found' : 'No MCP servers configured'}
isEmpty={filteredMcpServers.length === 0}
>
<div class="max-h-64 overflow-y-auto">
{#each filteredMcpServers as server (server.id)}
{@const healthState = mcpStore.getHealthCheckState(server.id)}
{@const hasError = healthState.status === HealthCheckStatus.ERROR}
{@const isEnabledForChat = isServerEnabledForChat(server.id)}
<button
type="button"
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
onclick={() => !hasError && toggleServerForChat(server.id)}
disabled={hasError}
>
<div class="flex min-w-0 flex-1 items-center gap-2">
{#if mcpStore.getServerFavicon(server.id)}
<img
src={mcpStore.getServerFavicon(server.id)}
alt=""
class="h-4 w-4 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
{/if}
<span class="truncate text-sm">{getServerLabel(server)}</span>
{#if hasError}
<span
class="shrink-0 rounded bg-destructive/15 px-1.5 py-0.5 text-xs text-destructive"
>
Error
</span>
{/if}
</div>
<Switch
checked={isEnabledForChat}
disabled={hasError}
onclick={(e: MouseEvent) => e.stopPropagation()}
onCheckedChange={() => toggleServerForChat(server.id)}
/>
</button>
{/each}
</div>
{#snippet footer()}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Settings class="h-4 w-4" />
<span>Manage MCP Servers</span>
</DropdownMenu.Item>
{/snippet}
</DropdownMenuSearchable>
</DropdownMenu.SubContent>
</DropdownMenu.Sub>
{#if hasMcpPromptsSupport}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpPromptClick}
>
<Zap class="h-4 w-4" />
<span>MCP Prompt</span>
</DropdownMenu.Item>
{/if}
{#if hasMcpResourcesSupport}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpResourcesClick}
>
<FolderOpen class="h-4 w-4" />
<span>MCP Resources</span>
</DropdownMenu.Item>
{/if}
</DropdownMenu.Content>
</DropdownMenu.Root>
</div>
@@ -1,170 +0,0 @@
<script lang="ts">
import { Plus, MessageSquare, Zap, FolderOpen } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as Sheet from '$lib/components/ui/sheet';
import { FILE_TYPE_ICONS } from '$lib/constants';
import { McpLogo } from '$lib/components/app';
interface Props {
class?: string;
disabled?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
hasMcpPromptsSupport?: boolean;
hasMcpResourcesSupport?: boolean;
onFileUpload?: () => void;
onSystemPromptClick?: () => void;
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
}
let {
class: className = '',
disabled = false,
hasAudioModality = false,
hasVisionModality = false,
hasMcpPromptsSupport = false,
hasMcpResourcesSupport = false,
onFileUpload,
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
}: Props = $props();
let sheetOpen = $state(false);
function handleMcpPromptClick() {
sheetOpen = false;
onMcpPromptClick?.();
}
function handleMcpSettingsClick() {
onMcpSettingsClick?.();
}
function handleMcpResourcesClick() {
sheetOpen = false;
onMcpResourcesClick?.();
}
function handleSheetFileUpload() {
sheetOpen = false;
onFileUpload?.();
}
function handleSheetSystemPromptClick() {
sheetOpen = false;
onSystemPromptClick?.();
}
const fileUploadTooltipText = 'Add files, system prompt or MCP Servers';
const sheetItemClass =
'flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-left text-sm transition-colors hover:bg-accent active:bg-accent disabled:cursor-not-allowed disabled:opacity-50';
</script>
<div class="flex items-center gap-1 {className}">
<Sheet.Root bind:open={sheetOpen}>
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
onclick={() => (sheetOpen = true)}
>
<span class="sr-only">{fileUploadTooltipText}</span>
<Plus class="h-4 w-4" />
</Button>
<Sheet.Content side="bottom" class="max-h-[85vh] gap-0">
<Sheet.Header>
<Sheet.Title>Add to chat</Sheet.Title>
<Sheet.Description class="sr-only">
Add files, system prompt or configure MCP servers
</Sheet.Description>
</Sheet.Header>
<div class="flex flex-col gap-1 overflow-y-auto px-1.5 pb-2">
<!-- Images -->
<button
type="button"
class={sheetItemClass}
disabled={!hasVisionModality}
onclick={handleSheetFileUpload}
>
<FILE_TYPE_ICONS.image class="h-4 w-4 shrink-0" />
<span>Images</span>
{#if !hasVisionModality}
<span class="ml-auto text-xs text-muted-foreground">Requires vision model</span>
{/if}
</button>
<!-- Audio -->
<button
type="button"
class={sheetItemClass}
disabled={!hasAudioModality}
onclick={handleSheetFileUpload}
>
<FILE_TYPE_ICONS.audio class="h-4 w-4 shrink-0" />
<span>Audio Files</span>
{#if !hasAudioModality}
<span class="ml-auto text-xs text-muted-foreground">Requires audio model</span>
{/if}
</button>
<button type="button" class={sheetItemClass} onclick={handleSheetFileUpload}>
<FILE_TYPE_ICONS.text class="h-4 w-4 shrink-0" />
<span>Text Files</span>
</button>
<button type="button" class={sheetItemClass} onclick={handleSheetFileUpload}>
<FILE_TYPE_ICONS.pdf class="h-4 w-4 shrink-0" />
<span>PDF Files</span>
{#if !hasVisionModality}
<span class="ml-auto text-xs text-muted-foreground">Text-only</span>
{/if}
</button>
<button type="button" class={sheetItemClass} onclick={handleSheetSystemPromptClick}>
<MessageSquare class="h-4 w-4 shrink-0" />
<span>System Message</span>
</button>
<button type="button" class={sheetItemClass} onclick={handleMcpSettingsClick}>
<McpLogo class="h-4 w-4 shrink-0" />
<span>MCP Servers</span>
</button>
{#if hasMcpPromptsSupport}
<button type="button" class={sheetItemClass} onclick={handleMcpPromptClick}>
<Zap class="h-4 w-4 shrink-0" />
<span>MCP Prompt</span>
</button>
{/if}
{#if hasMcpResourcesSupport}
<button type="button" class={sheetItemClass} onclick={handleMcpResourcesClick}>
<FolderOpen class="h-4 w-4 shrink-0" />
<span>MCP Resources</span>
</button>
{/if}
</div>
</Sheet.Content>
</Sheet.Root>
</div>
@@ -1,62 +1,39 @@
<script lang="ts">
import { Settings } from '@lucide/svelte';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { Settings, Plus } from '@lucide/svelte';
import { Switch } from '$lib/components/ui/switch';
import { DropdownMenuSearchable, McpActiveServersAvatars } from '$lib/components/app';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { McpLogo, DropdownMenuSearchable } from '$lib/components/app';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { HealthCheckStatus } from '$lib/enums';
import type { MCPServerSettingsEntry } from '$lib/types';
import { goto } from '$app/navigation';
interface Props {
class?: string;
disabled?: boolean;
onSettingsClick?: () => void;
onMcpSettingsClick?: () => void;
}
let { class: className = '', disabled = false, onSettingsClick }: Props = $props();
let { onMcpSettingsClick }: Props = $props();
let searchQuery = $state('');
let mcpServers = $derived(mcpStore.getServersSorted().filter((s) => s.enabled));
let mcpSearchQuery = $state('');
let allMcpServers = $derived(mcpStore.getServersSorted());
let mcpServers = $derived(allMcpServers.filter((s) => s.enabled));
let hasMcpServers = $derived(mcpServers.length > 0);
let enabledMcpServersForChat = $derived(
mcpServers.filter((s) => conversationsStore.isMcpServerEnabledForChat(s.id) && s.url.trim())
);
let healthyEnabledMcpServers = $derived(
enabledMcpServersForChat.filter((s) => {
const healthState = mcpStore.getHealthCheckState(s.id);
return healthState.status !== HealthCheckStatus.ERROR;
})
);
let hasEnabledMcpServers = $derived(enabledMcpServersForChat.length > 0);
let mcpFavicons = $derived(
healthyEnabledMcpServers
.slice(0, 3)
.map((s) => ({ id: s.id, url: mcpStore.getServerFavicon(s.id) }))
.filter((f) => f.url !== null)
);
// let hasAnyMcpServers = $derived(allMcpServers.length > 0);
let filteredMcpServers = $derived.by(() => {
const query = searchQuery.toLowerCase().trim();
if (query) {
return mcpServers.filter((s) => {
const name = getServerLabel(s).toLowerCase();
const url = s.url.toLowerCase();
return name.includes(query) || url.includes(query);
});
}
return mcpServers;
const query = mcpSearchQuery.toLowerCase().trim();
if (!query) return mcpServers;
return mcpServers.filter((s) => {
const name = getServerLabel(s).toLowerCase();
const url = s.url.toLowerCase();
return name.includes(query) || url.includes(query);
});
});
function getServerLabel(server: MCPServerSettingsEntry): string {
return mcpStore.getServerLabel(server);
}
function handleDropdownOpen(open: boolean) {
if (open) {
mcpStore.runHealthChecksForServers(mcpServers);
}
}
function isServerEnabledForChat(serverId: string): boolean {
return conversationsStore.isMcpServerEnabledForChat(serverId);
}
@@ -64,38 +41,33 @@
async function toggleServerForChat(serverId: string) {
await conversationsStore.toggleMcpServerForChat(serverId);
}
function handleMcpSubMenuOpen(open: boolean) {
if (open) {
mcpSearchQuery = '';
mcpStore.runHealthChecksForServers(allMcpServers);
}
}
function handleMcpSettingsClick() {
onMcpSettingsClick?.();
goto(`${hasMcpServers ? '' : '?add'}#/settings/mcp`);
}
</script>
{#if hasMcpServers && hasEnabledMcpServers && mcpFavicons.length > 0}
<DropdownMenu.Root
onOpenChange={(open) => {
if (!open) {
searchQuery = '';
}
handleDropdownOpen(open);
}}
>
<DropdownMenu.Trigger
{disabled}
onclick={(e) => {
e.preventDefault();
e.stopPropagation();
}}
>
<button
type="button"
class="inline-flex cursor-pointer items-center rounded-sm py-1 disabled:cursor-not-allowed disabled:opacity-60"
{disabled}
aria-label="MCP Servers"
>
<McpActiveServersAvatars class={className} />
</button>
</DropdownMenu.Trigger>
<DropdownMenu.Sub onOpenChange={handleMcpSubMenuOpen}>
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
<McpLogo class="h-4 w-4" />
<DropdownMenu.Content align="start" class="w-72 pt-0">
<span>MCP Servers</span>
</DropdownMenu.SubTrigger>
<DropdownMenu.SubContent class="w-72 pt-0">
{#if hasMcpServers}
<DropdownMenuSearchable
bind:searchValue={searchQuery}
placeholder="Search servers..."
bind:searchValue={mcpSearchQuery}
emptyMessage="No servers found"
isEmpty={filteredMcpServers.length === 0}
>
@@ -107,7 +79,7 @@
<button
type="button"
class="flex w-full items-center justify-between gap-2 px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
class="flex w-full items-center justify-between gap-2 rounded-sm px-2 py-2 text-left transition-colors hover:bg-accent disabled:cursor-not-allowed disabled:opacity-50"
onclick={() => !hasError && toggleServerForChat(server.id)}
disabled={hasError}
>
@@ -147,7 +119,7 @@
{#snippet footer()}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={onSettingsClick}
onclick={handleMcpSettingsClick}
>
<Settings class="h-4 w-4" />
@@ -155,6 +127,21 @@
</DropdownMenu.Item>
{/snippet}
</DropdownMenuSearchable>
</DropdownMenu.Content>
</DropdownMenu.Root>
{/if}
{:else}
<div class="px-2 py-3 text-center text-sm text-muted-foreground">
No MCP servers configured
</div>
<DropdownMenu.Separator />
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={handleMcpSettingsClick}
>
<Plus class="h-4 w-4" />
<span>Add MCP Servers</span>
</DropdownMenu.Item>
{/if}
</DropdownMenu.SubContent>
</DropdownMenu.Sub>
@@ -7,20 +7,13 @@
interface Props {
canSend?: boolean;
disabled?: boolean;
isLoading?: boolean;
showErrorState?: boolean;
tooltipLabel?: string;
}
let {
canSend = false,
disabled = false,
isLoading = false,
showErrorState = false,
tooltipLabel
}: Props = $props();
let { canSend = false, disabled = false, showErrorState = false, tooltipLabel }: Props = $props();
let isDisabled = $derived(!canSend || disabled || isLoading);
let isDisabled = $derived(!canSend || disabled);
</script>
{#snippet submitButton(props = {})}
@@ -0,0 +1,146 @@
<script lang="ts">
import { PencilRuler, ChevronDown, ChevronRight, Loader2, Info } from '@lucide/svelte';
import { Checkbox } from '$lib/components/ui/checkbox';
import * as Collapsible from '$lib/components/ui/collapsible';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import { toolsStore } from '$lib/stores/tools.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { useToolsPanel } from '$lib/hooks/use-tools-panel.svelte';
const toolsPanel = useToolsPanel();
const hasMcpServersAvailable = $derived(mcpStore.getServersSorted().length > 0);
</script>
<DropdownMenu.Sub onOpenChange={(open) => open && toolsPanel.handleOpen()}>
<DropdownMenu.SubTrigger class="flex cursor-pointer items-center gap-2">
<PencilRuler class="h-4 w-4" />
<span>Tools</span>
</DropdownMenu.SubTrigger>
<DropdownMenu.SubContent class="w-72 p-0">
{#if toolsPanel.totalToolCount === 0}
{#if toolsStore.loading}
<div class="px-3 py-4 text-center text-sm text-muted-foreground">
<Loader2 class="mx-auto mb-1 h-4 w-4 animate-spin" />
Loading tools...
</div>
{:else if toolsStore.isToolsEndpointUnreachable}
<div class="grid gap-2.5 px-3 py-4 text-sm text-muted-foreground">
<span class="flex gap-2">
<Info class="mt-0.5 h-4 w-4 shrink-0" />
<span
>Run llama-server with <code>--tools</code> flag to enable
<strong>Built-in Tools</strong>.</span
>
</span>
<span class="flex gap-2">
<Info class="mt-0.5 h-4 w-4 shrink-0" />
<span
>{hasMcpServersAvailable ? 'Enable' : 'Add'} MCP Server(s) to access
<strong>MCP Tools</strong>.</span
>
</span>
</div>
{:else if toolsStore.error}
<div class="px-3 py-4 text-center text-sm text-muted-foreground">Failed to load tools</div>
{:else if toolsPanel.noToolsInfoMessage}
<div class="flex gap-2 px-3 py-4 text-sm text-muted-foreground">
<Info class="mt-0.5 h-4 w-4 shrink-0" />
<span>{toolsPanel.noToolsInfoMessage}</span>
</div>
{:else}
<div class="px-3 py-4 text-center text-sm text-muted-foreground">No tools available</div>
{/if}
{:else}
<div class="max-h-80 overflow-y-auto p-2 pr-1">
{#each toolsPanel.activeGroups as group (group.label)}
{@const isExpanded = toolsPanel.expandedGroups.has(group.label)}
{@const { checked, indeterminate } = toolsPanel.getGroupCheckedState(group)}
{@const favicon = toolsPanel.getFavicon(group)}
<Collapsible.Root
open={isExpanded}
onOpenChange={() => toolsPanel.toggleGroupExpanded(group.label)}
>
<div class="flex items-center gap-1">
<Collapsible.Trigger
class="flex min-w-0 flex-1 items-center gap-2 rounded px-2 py-1.5 text-sm hover:bg-muted/50"
>
{#if isExpanded}
<ChevronDown class="h-3.5 w-3.5 shrink-0" />
{:else}
<ChevronRight class="h-3.5 w-3.5 shrink-0" />
{/if}
<span class="inline-flex min-w-0 items-center gap-1.5 font-medium">
{#if favicon}
<img
src={favicon}
alt=""
class="h-4 w-4 shrink-0 rounded-sm"
onerror={(e) => {
(e.currentTarget as HTMLImageElement).style.display = 'none';
}}
/>
{/if}
<span class="truncate">{group.label}</span>
</span>
<span class="ml-auto shrink-0 text-xs text-muted-foreground">
{toolsPanel.getEnabledToolCount(group)}/{group.tools.length}
</span>
</Collapsible.Trigger>
<Tooltip.Root>
<Tooltip.Trigger>
<Checkbox
{checked}
{indeterminate}
onCheckedChange={() => toolsStore.toggleGroup(group)}
class="mr-2 h-4 w-4 shrink-0"
/>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>
{checked ? 'Disable' : 'Enable'}
{group.tools.length} tool{group.tools.length !== 1 ? 's' : ''}
</p>
</Tooltip.Content>
</Tooltip.Root>
</div>
<Collapsible.Content>
<div class="ml-4 flex flex-col gap-0.5 border-l border-border/50 pl-2">
{#each group.tools as tool (tool.function.name)}
<button
type="button"
class="flex w-full items-center gap-2 rounded px-2 py-1.5 text-left text-sm transition-colors hover:bg-muted/50"
onclick={() => toolsStore.toggleTool(tool.function.name)}
>
<Checkbox
checked={toolsStore.isToolEnabled(tool.function.name)}
onCheckedChange={() => toolsStore.toggleTool(tool.function.name)}
class="h-4 w-4 shrink-0"
/>
<span class="min-w-0 flex-1 truncate font-mono text-[12px]">
{tool.function.name}
</span>
</button>
{/each}
</div>
</Collapsible.Content>
</Collapsible.Root>
{/each}
</div>
{/if}
</DropdownMenu.SubContent>
</DropdownMenu.Sub>
@@ -6,21 +6,19 @@
ChatFormActionAttachmentsSheet,
ChatFormActionRecord,
ChatFormActionSubmit,
McpServersSelector,
ModelsSelector,
ModelsSelectorDropdown,
ModelsSelectorSheet
} from '$lib/components/app';
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { getChatSettingsDialogContext } from '$lib/contexts';
import { FileTypeCategory } from '$lib/enums';
import { getFileTypeCategory } from '$lib/utils';
import { config } from '$lib/stores/settings.svelte';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isRouterMode, serverError } from '$lib/stores/server.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { config } from '$lib/stores/settings.svelte';
import { activeMessages, conversationsStore } from '$lib/stores/conversations.svelte';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { getFileTypeCategory } from '$lib/utils';
import { goto } from '$app/navigation';
interface Props {
canSend?: boolean;
@@ -165,7 +163,8 @@
return '';
});
let selectorModelRef: ModelsSelector | ModelsSelectorSheet | undefined = $state(undefined);
let selectorModelRef: ModelsSelectorDropdown | ModelsSelectorSheet | undefined =
$state(undefined);
let isMobile = new IsMobile();
@@ -173,8 +172,6 @@
selectorModelRef?.open();
}
const chatSettingsDialog = getChatSettingsDialogContext();
let hasMcpPromptsSupport = $derived.by(() => {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
@@ -200,8 +197,8 @@
{onFileUpload}
{onSystemPromptClick}
{onMcpPromptClick}
onMcpSettingsClick={() => goto('#/settings/mcp')}
{onMcpResourcesClick}
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
{:else}
<ChatFormActionAttachmentsDropdown
@@ -214,17 +211,12 @@
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
onMcpSettingsClick={() => goto('#/settings/mcp')}
/>
{/if}
<McpServersSelector
{disabled}
onSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
</div>
<div class="ml-auto flex items-center gap-1.5">
<div class="ml-auto flex items-center gap-2">
{#if isMobile.current}
<ModelsSelectorSheet
disabled={disabled || isOffline}
@@ -234,7 +226,7 @@
useGlobalSelection
/>
{:else}
<ModelsSelector
<ModelsSelectorDropdown
disabled={disabled || isOffline}
bind:this={selectorModelRef}
currentModel={conversationModel}
@@ -244,7 +236,7 @@
{/if}
</div>
{#if isLoading}
{#if isLoading && !hasText}
<Button
type="button"
variant="secondary"
@@ -263,7 +255,6 @@
<ChatFormActionSubmit
canSend={canSend && hasModelSelected && isSelectedModelInCache}
{disabled}
{isLoading}
tooltipLabel={submitTooltip}
showErrorState={hasModelSelected && !isSelectedModelInCache}
/>
@@ -0,0 +1,182 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import * as Tooltip from '$lib/components/ui/tooltip';
import {
ATTACHMENT_FILE_ITEMS,
ATTACHMENT_EXTRA_ITEMS,
ATTACHMENT_MCP_ITEMS,
ATTACHMENT_TOOLTIP_TEXT,
TOOLTIP_DELAY_DURATION
} from '$lib/constants';
import { AttachmentMenuItemId } from '$lib/enums';
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
interface Props {
class?: string;
disabled?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
hasMcpPromptsSupport?: boolean;
hasMcpResourcesSupport?: boolean;
onFileUpload?: () => void;
onSystemPromptClick?: () => void;
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
}
let {
class: className = '',
disabled = false,
hasAudioModality = false,
hasVisionModality = false,
hasMcpPromptsSupport = false,
hasMcpResourcesSupport = false,
onFileUpload,
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
}: Props = $props();
let dropdownOpen = $state(false);
function handleMcpSettingsClick() {
dropdownOpen = false;
onMcpSettingsClick?.();
}
const attachmentMenu = useAttachmentMenu(
() => ({ hasVisionModality, hasAudioModality, hasMcpPromptsSupport, hasMcpResourcesSupport }),
() => ({ onFileUpload, onSystemPromptClick, onMcpPromptClick, onMcpResourcesClick }),
() => {
dropdownOpen = false;
}
);
</script>
<div class="flex items-center gap-1 {className}">
<DropdownMenu.Root bind:open={dropdownOpen}>
<DropdownMenu.Trigger name="Attach files" {disabled}>
<Tooltip.Root>
<Tooltip.Trigger class="w-full">
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
</Tooltip.Trigger>
<Tooltip.Content>
<p>{ATTACHMENT_TOOLTIP_TEXT}</p>
</Tooltip.Content>
</Tooltip.Root>
</DropdownMenu.Trigger>
<DropdownMenu.Content align="start" class="w-48">
{#each ATTACHMENT_FILE_ITEMS as item (item.id)}
{@const enabled = attachmentMenu.isItemEnabled(item.enabledWhen)}
{#if enabled}
<DropdownMenu.Item
class="{item.class ?? ''} flex cursor-pointer items-center gap-2"
onclick={() => attachmentMenu.callbacks[item.action]()}
>
<item.icon class="h-4 w-4" />
<span>{item.label}</span>
</DropdownMenu.Item>
{:else if item.disabledTooltip}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="{item.class ?? ''} flex cursor-pointer items-center gap-2"
disabled
>
<item.icon class="h-4 w-4" />
<span>{item.label}</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{item.disabledTooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{/each}
{#if !attachmentMenu.isItemEnabled('hasVisionModality')}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={attachmentMenu.callbacks.onFileUpload}
>
{@const pdfItem = ATTACHMENT_FILE_ITEMS.find(
(i) => i.id === AttachmentMenuItemId.PDF
)}
{#if pdfItem}
<pdfItem.icon class="h-4 w-4" />
<span>{pdfItem.label}</span>
{/if}
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
<DropdownMenu.Separator />
{#each ATTACHMENT_EXTRA_ITEMS as item (item.id)}
{#if item.id === AttachmentMenuItemId.SYSTEM_MESSAGE}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger class="w-full">
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => attachmentMenu.callbacks[item.action]()}
>
<item.icon class="h-4 w-4" />
<span>{item.label}</span>
</DropdownMenu.Item>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{attachmentMenu.getSystemMessageTooltip()}</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{/each}
<ChatFormActionToolsSubmenu />
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
{#if attachmentMenu.isItemVisible(item.visibleWhen)}
<DropdownMenu.Item
class="flex cursor-pointer items-center gap-2"
onclick={() => attachmentMenu.callbacks[item.action]()}
>
<item.icon class="h-4 w-4" />
<span>{item.label}</span>
</DropdownMenu.Item>
{/if}
{/each}
</DropdownMenu.Content>
</DropdownMenu.Root>
</div>
@@ -0,0 +1,184 @@
<script lang="ts">
import { Plus } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as Tooltip from '$lib/components/ui/tooltip';
import * as Sheet from '$lib/components/ui/sheet';
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
import {
ATTACHMENT_FILE_ITEMS,
ATTACHMENT_EXTRA_ITEMS,
ATTACHMENT_MCP_ITEMS,
ATTACHMENT_TOOLTIP_TEXT
} from '$lib/constants/attachment-menu';
import { ChatFormActionToolsSubmenu, ChatFormActionMcpServersSubmenu } from '$lib/components/app';
import { useAttachmentMenu } from '$lib/hooks/use-attachment-menu.svelte';
import { AttachmentMenuItemId } from '$lib/enums';
interface Props {
class?: string;
disabled?: boolean;
hasAudioModality?: boolean;
hasVisionModality?: boolean;
hasMcpPromptsSupport?: boolean;
hasMcpResourcesSupport?: boolean;
onFileUpload?: () => void;
onSystemPromptClick?: () => void;
onMcpPromptClick?: () => void;
onMcpSettingsClick?: () => void;
onMcpResourcesClick?: () => void;
}
let {
class: className = '',
disabled = false,
hasAudioModality = false,
hasVisionModality = false,
hasMcpPromptsSupport = false,
hasMcpResourcesSupport = false,
onFileUpload,
onSystemPromptClick,
onMcpPromptClick,
onMcpSettingsClick,
onMcpResourcesClick
}: Props = $props();
let sheetOpen = $state(false);
const attachmentMenu = useAttachmentMenu(
() => ({ hasVisionModality, hasAudioModality, hasMcpPromptsSupport, hasMcpResourcesSupport }),
() => ({ onFileUpload, onSystemPromptClick, onMcpPromptClick, onMcpResourcesClick }),
() => {
sheetOpen = false;
}
);
function handleMcpSettingsClick() {
sheetOpen = false;
onMcpSettingsClick?.();
}
const sheetItemClass =
'flex w-full items-center gap-3 rounded-md px-3 py-2.5 text-left text-sm transition-colors hover:bg-accent active:bg-accent disabled:cursor-not-allowed disabled:opacity-50';
</script>
<div class="flex items-center gap-1 {className}">
<Sheet.Root bind:open={sheetOpen}>
<Button
class="file-upload-button h-8 w-8 rounded-full p-0"
{disabled}
variant="secondary"
type="button"
onclick={() => (sheetOpen = true)}
>
<span class="sr-only">{ATTACHMENT_TOOLTIP_TEXT}</span>
<Plus class="h-4 w-4" />
</Button>
<Sheet.Content side="bottom" class="max-h-[85vh] gap-0 overflow-y-auto">
<Sheet.Header>
<Sheet.Title>Add to chat</Sheet.Title>
<Sheet.Description class="sr-only">
Add files, system prompt or configure MCP servers
</Sheet.Description>
</Sheet.Header>
<div class="flex flex-col gap-1 px-1.5 pb-2">
{#each ATTACHMENT_FILE_ITEMS as item (item.id)}
{@const enabled = attachmentMenu.isItemEnabled(item.enabledWhen)}
{#if enabled}
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[item.action]()}
>
<item.icon class="h-4 w-4 shrink-0" />
<span>{item.label}</span>
</button>
{:else if item.disabledTooltip}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger>
<button type="button" class={sheetItemClass} disabled>
<item.icon class="h-4 w-4 shrink-0" />
<span>{item.label}</span>
</button>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{item.disabledTooltip}</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{/each}
{#if !attachmentMenu.isItemEnabled('hasVisionModality')}
{@const pdfItem = ATTACHMENT_FILE_ITEMS.find((i) => i.id === AttachmentMenuItemId.PDF)}
{#if pdfItem}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[pdfItem.action]()}
>
<pdfItem.icon class="h-4 w-4 shrink-0" />
<span>{pdfItem.label}</span>
</button>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>PDFs will be converted to text. Image-based PDFs may not work properly.</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{/if}
{#each ATTACHMENT_EXTRA_ITEMS as item (item.id)}
{#if item.id === AttachmentMenuItemId.SYSTEM_MESSAGE}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
<Tooltip.Trigger>
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[item.action]()}
>
<item.icon class="h-4 w-4 shrink-0" />
<span>{item.label}</span>
</button>
</Tooltip.Trigger>
<Tooltip.Content side="right">
<p>{attachmentMenu.getSystemMessageTooltip()}</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{/each}
<div class="my-2 border-t"></div>
<ChatFormActionToolsSubmenu />
<ChatFormActionMcpServersSubmenu onMcpSettingsClick={handleMcpSettingsClick} />
{#each ATTACHMENT_MCP_ITEMS as item (item.id)}
{#if attachmentMenu.isItemVisible(item.visibleWhen)}
<button
type="button"
class={sheetItemClass}
onclick={() => attachmentMenu.callbacks[item.action]()}
>
<item.icon class="h-4 w-4 shrink-0" />
<span>{item.label}</span>
</button>
{/if}
{/each}
</div>
</Sheet.Content>
</Sheet.Root>
</div>
@@ -1,12 +1,12 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { base } from '$app/paths';
import { getChatActionsContext, setMessageEditContext } from '$lib/contexts';
import { chatStore, pendingEditMessageId } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { DatabaseService } from '$lib/services/database.service';
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants';
import { MessageRole, AttachmentType } from '$lib/enums';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import {
ChatMessageAssistant,
ChatMessageUser,
@@ -118,7 +118,7 @@
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
if (conversationDeleted) {
goto(`${base}/`);
goto(`#/`);
}
return;
@@ -138,7 +138,7 @@
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
if (conversationDeleted) {
goto(`${base}/`);
goto(`#/`);
}
} else {
chatActions.delete(message);
@@ -200,7 +200,7 @@
const conversationDeleted = await chatStore.removeSystemPromptPlaceholder(message.id);
isEditing = false;
if (conversationDeleted) {
goto(`${base}/`);
goto(`#/`);
}
return;
}
@@ -252,70 +252,72 @@
}
</script>
{#if message.role === MessageRole.SYSTEM}
<ChatMessageSystem
bind:textareaElement
class={className}
{deletionInfo}
{message}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else if mcpPromptExtra}
<ChatMessageMcpPrompt
class={className}
{deletionInfo}
{message}
mcpPrompt={mcpPromptExtra}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else if message.role === MessageRole.USER}
<ChatMessageUser
class={className}
{deletionInfo}
{message}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else}
<ChatMessageAssistant
bind:textareaElement
class={className}
{deletionInfo}
{isLastAssistantMessage}
{message}
{toolMessages}
messageContent={message.content}
onConfirmDelete={handleConfirmDelete}
onContinue={handleContinue}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onRegenerate={handleRegenerate}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{/if}
<div use:fadeInView>
{#if message.role === MessageRole.SYSTEM}
<ChatMessageSystem
bind:textareaElement
class={className}
{deletionInfo}
{message}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else if mcpPromptExtra}
<ChatMessageMcpPrompt
class={className}
{deletionInfo}
{message}
mcpPrompt={mcpPromptExtra}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else if message.role === MessageRole.USER}
<ChatMessageUser
class={className}
{deletionInfo}
{message}
onConfirmDelete={handleConfirmDelete}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{:else}
<ChatMessageAssistant
bind:textareaElement
class={className}
{deletionInfo}
{isLastAssistantMessage}
{message}
{toolMessages}
messageContent={message.content}
onConfirmDelete={handleConfirmDelete}
onContinue={handleContinue}
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onRegenerate={handleRegenerate}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
{siblingInfo}
/>
{/if}
</div>
@@ -0,0 +1,23 @@
<script lang="ts">
import type { Snippet, Component } from 'svelte';
interface Props {
icon: Component<{ class?: string }>;
message: Snippet;
actions: Snippet;
}
let { icon: Icon, message, actions }: Props = $props();
</script>
<div class="my-2 rounded-lg border border-border bg-card p-3">
<div class="mb-3 flex items-center gap-2 text-sm">
<Icon class="h-4 w-4 shrink-0 text-muted-foreground" />
<span>
{@render message()}
</span>
</div>
<div class="flex flex-wrap items-center gap-2">
{@render actions()}
</div>
</div>
@@ -1,38 +1,104 @@
<script lang="ts">
import { Wrench, Loader2, Brain } from '@lucide/svelte';
import {
ChatMessageStatistics,
CollapsibleContentBlock,
MarkdownContent,
SyntaxHighlightedCode
SyntaxHighlightedCode,
ChatMessagePermissionRequest,
ChatMessageContinueRequest
} from '$lib/components/app';
import { config } from '$lib/stores/settings.svelte';
import { Wrench, Loader2, Brain } from '@lucide/svelte';
import { AgenticSectionType, FileTypeText } from '$lib/enums';
import { formatJsonPretty } from '$lib/utils';
import {
AgenticSectionType,
ChatMessageStatsView,
FileTypeText,
ToolPermissionDecision
} from '$lib/enums';
import type {
ChatMessageAgenticTimings,
ChatMessageAgenticTurnStats,
DatabaseMessage
} from '$lib/types';
import {
deriveAgenticSections,
formatJsonPretty,
parseToolResultWithImages,
type AgenticSection,
type ToolResultLine
} from '$lib/utils';
import type { DatabaseMessage } from '$lib/types/database';
import type { ChatMessageAgenticTimings, ChatMessageAgenticTurnStats } from '$lib/types/chat';
import { ChatMessageStatsView } from '$lib/enums';
import {
agenticPendingPermissionRequest,
agenticResolvePermission,
agenticPendingContinueRequest,
agenticResolveContinue
} from '$lib/stores/agentic.svelte';
import { config } from '$lib/stores/settings.svelte';
interface Props {
message: DatabaseMessage;
toolMessages?: DatabaseMessage[];
isStreaming?: boolean;
isLastAssistantMessage?: boolean;
highlightTurns?: boolean;
}
let { message, toolMessages = [], isStreaming = false, highlightTurns = false }: Props = $props();
let {
message,
toolMessages = [],
isStreaming = false,
isLastAssistantMessage = false,
highlightTurns = false
}: Props = $props();
let expandedStates: Record<number, boolean> = $state({});
const showToolCallInProgress = $derived(config().showToolCallInProgress as boolean);
const showThoughtInProgress = $derived(config().showThoughtInProgress as boolean);
let permissionDismissed = $state(false);
const pendingPermission = $derived(
isStreaming && isLastAssistantMessage ? agenticPendingPermissionRequest(message.convId) : null
);
// Reset dismissed when pendingPermission changes (new request or cleared)
let prevPendingRef: typeof pendingPermission = null;
$effect(() => {
if (pendingPermission !== prevPendingRef) {
prevPendingRef = pendingPermission;
if (pendingPermission) {
permissionDismissed = false;
}
}
});
function handlePermission(decision: ToolPermissionDecision) {
permissionDismissed = true;
agenticResolvePermission(message.convId, decision);
}
let continueDismissed = $state(false);
const pendingContinue = $derived(
isStreaming && isLastAssistantMessage ? agenticPendingContinueRequest(message.convId) : false
);
let prevContinueRef = false;
$effect(() => {
if (pendingContinue !== prevContinueRef) {
prevContinueRef = pendingContinue;
if (pendingContinue) {
continueDismissed = false;
}
}
});
function handleContinue(shouldContinue: boolean) {
continueDismissed = true;
agenticResolveContinue(message.convId, shouldContinue);
}
const sections = $derived(deriveAgenticSections(message, toolMessages, [], isStreaming));
// Parse tool results with images
@@ -201,7 +267,11 @@
<Loader2 class="h-3 w-3 animate-spin" />
{/if}
</div>
{#if section.toolResult}
{#if isPending}
<div class="rounded bg-muted/30 p-2 text-xs text-muted-foreground italic">
Waiting for result...
</div>
{:else if section.toolResult}
<div class="overflow-auto rounded-lg border border-border bg-muted p-4">
{#each section.parsedLines as line, i (i)}
<div class="font-mono text-xs leading-relaxed whitespace-pre-wrap">{line.text}</div>
@@ -215,10 +285,8 @@
{/if}
{/each}
</div>
{:else if isPending}
<div class="rounded bg-muted/30 p-2 text-xs text-muted-foreground italic">
Waiting for result...
</div>
{:else}
<div class="rounded bg-muted/30 p-2 text-xs text-muted-foreground italic">No output</div>
{/if}
</div>
</CollapsibleContentBlock>
@@ -289,6 +357,18 @@
{@render renderSection(section, index)}
{/each}
{/if}
{#if pendingPermission && !permissionDismissed}
<ChatMessagePermissionRequest
toolName={pendingPermission.toolName}
serverLabel={pendingPermission.serverLabel}
onDecision={handlePermission}
/>
{/if}
{#if pendingContinue && !continueDismissed}
<ChatMessageContinueRequest onDecision={handleContinue} />
{/if}
</div>
<style>
@@ -4,7 +4,7 @@
ChatMessageActions,
ChatMessageStatistics,
ModelBadge,
ModelsSelector
ModelsSelectorDropdown
} from '$lib/components/app';
import { getMessageEditContext } from '$lib/contexts';
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
@@ -308,6 +308,7 @@
{message}
{toolMessages}
isStreaming={isChatStreaming()}
{isLastAssistantMessage}
highlightTurns={highlightAgenticTurns}
/>
{/if}
@@ -336,10 +337,10 @@
class="inline-flex flex-wrap items-start gap-2 text-xs text-muted-foreground"
>
{#if isRouter}
<ModelsSelector
<ModelsSelectorDropdown
currentModel={displayedModel}
disabled={isLoading()}
onModelChange={async (modelId, modelName) => {
onModelChange={async (modelId: string, modelName: string) => {
const status = modelsStore.getModelStatus(modelId);
if (status !== ServerModelStatus.LOADED) {
@@ -0,0 +1,30 @@
<script lang="ts">
import { RotateCw } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import ChatMessageActionCard from './ChatMessageActionCard.svelte';
interface Props {
onDecision: (shouldContinue: boolean) => void;
}
let { onDecision }: Props = $props();
</script>
<ChatMessageActionCard icon={RotateCw}>
{#snippet message()}
Agentic turn limit reached. Continue?
{/snippet}
{#snippet actions()}
<Button size="sm" onclick={() => onDecision(true)}>Continue</Button>
<Button
variant="destructive"
size="sm"
class="text-destructive hover:text-destructive"
onclick={() => onDecision(false)}
>
Stop
</Button>
{/snippet}
</ChatMessageActionCard>
@@ -0,0 +1,88 @@
<script lang="ts">
import { ChevronDown, ShieldQuestion } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import * as ButtonGroup from '$lib/components/ui/button-group';
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
import { ToolSource, ToolPermissionDecision } from '$lib/enums';
import { TOOL_SERVER_LABELS } from '$lib/constants';
import { toolsStore } from '$lib/stores/tools.svelte';
import ChatMessageActionCard from './ChatMessageActionCard.svelte';
interface Props {
toolName: string;
serverLabel: string;
onDecision: (decision: ToolPermissionDecision) => void;
}
let { toolName, serverLabel, onDecision }: Props = $props();
</script>
<ChatMessageActionCard icon={ShieldQuestion}>
{#snippet message()}
Allow use of
<span class="font-semibold">{toolName}</span>
{#if serverLabel}
from <span class="font-semibold">{serverLabel}</span>
{/if}
?
{/snippet}
{#snippet actions()}
<DropdownMenu.Root>
<ButtonGroup.Root
class="overflow-hidden rounded-md bg-foreground text-white shadow-sm dark:bg-secondary dark:text-foreground"
>
<Button
class="rounded-none! shadow-none!"
size="sm"
onclick={() => onDecision(ToolPermissionDecision.ONCE)}
>
Allow once
</Button>
<ButtonGroup.Separator />
<DropdownMenu.Trigger>
<Button size="sm" class="rounded-none! !ps-2 shadow-none!">
<ChevronDown class="h-3.5 w-3.5" />
</Button>
</DropdownMenu.Trigger>
</ButtonGroup.Root>
<DropdownMenu.Content align="start" class="min-w-[8rem]">
<DropdownMenu.Item onclick={() => onDecision(ToolPermissionDecision.ALWAYS)}>
Always allow <pre>{toolName}</pre>
tool
</DropdownMenu.Item>
{#if serverLabel}
<DropdownMenu.Item onclick={() => onDecision(ToolPermissionDecision.ALWAYS_SERVER)}>
Always allow all tools from {serverLabel}
</DropdownMenu.Item>
{:else}
{@const source = toolsStore.getToolSource(toolName)}
{@const providerName =
source === ToolSource.BUILTIN
? TOOL_SERVER_LABELS[ToolSource.BUILTIN]
: source === ToolSource.CUSTOM
? TOOL_SERVER_LABELS[ToolSource.CUSTOM]
: 'MCP Tools'}
<DropdownMenu.Item onclick={() => onDecision(ToolPermissionDecision.ALWAYS_SERVER)}>
Approve all tools from {providerName}
</DropdownMenu.Item>
{/if}
</DropdownMenu.Content>
</DropdownMenu.Root>
<Button
variant="destructive"
size="sm"
class="text-destructive hover:text-destructive"
onclick={() => onDecision(ToolPermissionDecision.DENY)}
>
Deny
</Button>
{/snippet}
</ChatMessageActionCard>
@@ -1,11 +1,9 @@
<script lang="ts">
import { Card } from '$lib/components/ui/card';
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
import { getMessageEditContext } from '$lib/contexts';
import { config } from '$lib/stores/settings.svelte';
import ChatMessageActions from './ChatMessageActions.svelte';
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
import { MessageRole } from '$lib/enums';
import ChatMessageUserBubble from './ChatMessageUserBubble.svelte';
interface Props {
class?: string;
@@ -44,34 +42,6 @@
// Get contexts
const editCtx = getMessageEditContext();
let isMultiline = $state(false);
let messageElement: HTMLElement | undefined = $state();
const currentConfig = config();
$effect(() => {
if (!messageElement || !message.content.trim()) return;
if (message.content.includes('\n')) {
isMultiline = true;
return;
}
const resizeObserver = new ResizeObserver((entries) => {
for (const entry of entries) {
const element = entry.target as HTMLElement;
const estimatedSingleLineHeight = 24; // Typical line height for text-md
isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
}
});
resizeObserver.observe(messageElement);
return () => {
resizeObserver.disconnect();
};
});
</script>
<div
@@ -82,29 +52,11 @@
{#if editCtx.isEditing}
<ChatMessageEditForm />
{:else}
{#if message.extra && message.extra.length > 0}
<div class="mb-2 max-w-[80%]">
<ChatAttachmentsList attachments={message.extra} readonly imageHeight="h-80" />
</div>
{/if}
{#if message.content.trim()}
<Card
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 text-foreground backdrop-blur-md data-[multiline]:py-2.5 dark:bg-primary/15"
data-multiline={isMultiline ? '' : undefined}
style="max-height: var(--max-message-height); overflow-wrap: anywhere; word-break: break-word;"
>
{#if currentConfig.renderUserContentAsMarkdown}
<div bind:this={messageElement}>
<MarkdownContent class="markdown-user-content -my-4" content={message.content} />
</div>
{:else}
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
{message.content}
</span>
{/if}
</Card>
{/if}
<ChatMessageUserBubble
content={message.content}
attachments={message.extra}
renderMarkdown={true}
/>
{#if message.timestamp}
<div class="max-w-[80%]">
@@ -0,0 +1,76 @@
<script lang="ts">
import { Card } from '$lib/components/ui/card';
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
import { config } from '$lib/stores/settings.svelte';
import type { DatabaseMessageExtra } from '$lib/types/database';
interface Props {
content: string;
attachments?: DatabaseMessageExtra[];
renderMarkdown?: boolean;
textColorClass?: string;
cardBgClass?: string;
maxHeightStyle?: string;
}
let {
content,
attachments = [],
renderMarkdown = false,
textColorClass = 'text-foreground',
cardBgClass = 'dark:bg-primary/15',
maxHeightStyle = 'max-height: var(--max-message-height);'
}: Props = $props();
let isMultiline = $state(false);
let messageElement: HTMLElement | undefined = $state();
const currentConfig = config();
$effect(() => {
if (!messageElement || !content.trim()) return;
if (content.includes('\n')) {
isMultiline = true;
return;
}
const resizeObserver = new ResizeObserver((entries) => {
for (const entry of entries) {
const element = entry.target as HTMLElement;
const estimatedSingleLineHeight = 24; // Typical line height for text-md
isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
}
});
resizeObserver.observe(messageElement);
return () => {
resizeObserver.disconnect();
};
});
</script>
{#if attachments && attachments.length > 0}
<div class="mb-2 max-w-[80%]">
<ChatAttachmentsList {attachments} readonly imageHeight="h-40" />
</div>
{/if}
{#if content.trim()}
<Card
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
data-multiline={isMultiline ? '' : undefined}
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
>
{#if renderMarkdown && currentConfig.renderUserContentAsMarkdown}
<div bind:this={messageElement}>
<MarkdownContent class="markdown-user-content -my-4" {content} />
</div>
{:else}
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
{content}
</span>
{/if}
</Card>
{/if}
@@ -0,0 +1,71 @@
<script lang="ts">
import { ActionIcon } from '$lib/components/app';
import ChatMessageEditForm from './ChatMessageEditForm.svelte';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { ArrowUp, Edit, Trash2 } from '@lucide/svelte';
import { getProcessingInfoContext } from '$lib/contexts';
import { useMessageEditContext } from '$lib/hooks/use-message-edit-context.svelte';
import ChatMessageUserBubble from './ChatMessageUserBubble.svelte';
interface Props {
class?: string;
content: string;
extras?: DatabaseMessageExtra[];
onSendImmediately: () => void;
onEdit: (newContent: string, extras?: DatabaseMessageExtra[]) => void;
onDelete: () => void;
}
let {
class: className = '',
content,
extras = [],
onSendImmediately,
onEdit,
onDelete
}: Props = $props();
const processingInfoCtx = getProcessingInfoContext();
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
const editCtx = useMessageEditContext({
getContent: () => content,
getExtras: () => extras,
onSave: (content, extras) => onEdit(content, extras)
});
</script>
<div
use:fadeInView
aria-label="Pending user message"
class="group flex flex-col items-end gap-3 transition-opacity hover:opacity-80 md:gap-2 {className} sticky {showProcessingInfo
? 'bottom-44'
: 'bottom-32'}"
role="group"
>
{#if editCtx.isEditing}
<ChatMessageEditForm />
{:else}
<ChatMessageUserBubble
{content}
attachments={extras}
textColorClass="text-muted-foreground"
cardBgClass="dark:bg-primary/8"
maxHeightStyle="overflow-wrap: anywhere; word-break: break-word;"
/>
<div class="max-w-[80%]">
<div class="relative flex h-6 items-center justify-between">
<div class="right-0 flex items-center gap-2 opacity-100 transition-opacity">
<div
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-hover:opacity-100"
>
<ActionIcon icon={Edit} tooltip="Edit" onclick={editCtx.handleEdit} />
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
<ActionIcon icon={ArrowUp} tooltip="Send immediately" onclick={onSendImmediately} />
</div>
</div>
</div>
</div>
{/if}
</div>
@@ -1,11 +1,23 @@
<script lang="ts">
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { ChatMessage } from '$lib/components/app';
import ChatMessageUserPending from './ChatMessageUserPending.svelte';
import { setChatActionsContext } from '$lib/contexts';
import { MessageRole } from '$lib/enums';
import { chatStore } from '$lib/stores/chat.svelte';
import {
chatPendingMessageContent,
chatPendingMessageExtras,
chatClearPendingMessage,
chatInjectPendingMessage
} from '$lib/stores/chat.svelte';
import { conversationsStore, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import {
agenticPendingSteeringMessageContent,
agenticPendingSteeringMessageExtras,
agenticClearSteeringMessage,
agenticInjectSteeringMessage
} from '$lib/stores/agentic.svelte';
import {
copyToClipboard,
formatMessageForClipboard,
@@ -14,12 +26,11 @@
} from '$lib/utils';
interface Props {
class?: string;
messages?: DatabaseMessage[];
onUserAction?: () => void;
}
let { class: className, messages = [], onUserAction }: Props = $props();
let { messages = [], onUserAction }: Props = $props();
let allConversationMessages = $state<DatabaseMessage[]>([]);
const currentConfig = config();
@@ -196,19 +207,42 @@
});
</script>
<div
class="flex h-full flex-col space-y-10 pt-24 {className}"
style="height: auto; min-height: calc(100dvh - 14rem);"
>
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
<div use:fadeInView>
<ChatMessage
class="mx-auto w-full max-w-[48rem]"
{message}
{toolMessages}
{isLastAssistantMessage}
{siblingInfo}
/>
</div>
{/each}
</div>
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
<ChatMessage
class="mx-auto mt-12 w-full max-w-[48rem]"
{message}
{toolMessages}
{isLastAssistantMessage}
{siblingInfo}
/>
{/each}
{#if activeConversation() && agenticPendingSteeringMessageContent(activeConversation()!.id)}
{@const convId = activeConversation()!.id}
{@const pendingContent = agenticPendingSteeringMessageContent(convId)}
{#if pendingContent}
<ChatMessageUserPending
class="mx-auto mt-12 w-full max-w-[48rem]"
content={pendingContent}
extras={agenticPendingSteeringMessageExtras(convId)}
onSendImmediately={() => chatStore.abortCurrentFlow(convId)}
onEdit={(newContent, extras) => agenticInjectSteeringMessage(convId, newContent, extras)}
onDelete={() => agenticClearSteeringMessage(convId)}
/>
{/if}
{:else if activeConversation() && chatPendingMessageContent(activeConversation()!.id)}
{@const convId = activeConversation()!.id}
{@const pendingContent = chatPendingMessageContent(convId)}
{#if pendingContent}
<ChatMessageUserPending
class="mx-auto mt-12 w-full max-w-[48rem]"
content={pendingContent}
extras={chatPendingMessageExtras(convId)}
onSendImmediately={() => chatStore.abortCurrentFlow(convId)}
onEdit={(newContent, extras) => chatInjectPendingMessage(convId, newContent, extras)}
onDelete={() => chatClearPendingMessage(convId)}
/>
{/if}
{/if}
@@ -2,7 +2,6 @@
import { afterNavigate } from '$app/navigation';
import {
ChatScreenForm,
ChatScreenHeader,
ChatMessages,
ChatScreenProcessingInfo,
DialogEmptyFileAlert,
@@ -12,15 +11,16 @@
} from '$lib/components/app';
import * as Alert from '$lib/components/ui/alert';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { KeyboardKey } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
import {
chatStore,
errorDialog,
isLoading,
isChatStreaming,
isEditing,
getAddFilesHandler
getAddFilesHandler,
activeProcessingState
} from '$lib/stores/chat.svelte';
import {
conversationsStore,
@@ -34,9 +34,11 @@
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
import { ErrorDialogType } from '$lib/enums';
import { onMount } from 'svelte';
import { fade, fly, slide } from 'svelte/transition';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { Trash2, AlertTriangle, RefreshCw } from '@lucide/svelte';
import ChatScreenDragOverlay from './ChatScreenDragOverlay.svelte';
import { page } from '$app/state';
import { setProcessingInfoContext } from '$lib/contexts';
let { showCenteredEmpty = false } = $props();
@@ -79,6 +81,18 @@
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
let showProcessingInfo = $derived(
isCurrentConversationLoading ||
(config().keepStatsVisible && !!page.params.id) ||
activeProcessingState() !== null
);
setProcessingInfoContext({
get showProcessingInfo() {
return showProcessingInfo;
}
});
let isRouter = $derived(isRouterMode());
let conversationModel = $derived(
@@ -208,20 +222,13 @@
processFiles(files);
}
function handleKeydown(event: KeyboardEvent) {
const isCtrlOrCmd = event.ctrlKey || event.metaKey;
if (
isCtrlOrCmd &&
event.shiftKey &&
(event.key === KeyboardKey.D_LOWER || event.key === KeyboardKey.D_UPPER)
) {
event.preventDefault();
const { handleKeydown } = useKeyboardShortcuts({
deleteActiveConversation: () => {
if (activeConversation()) {
showDeleteDialog = true;
}
}
}
});
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
if (draft.message || draft.files.length > 0) {
@@ -342,9 +349,9 @@
<svelte:window onkeydown={handleKeydown} />
<ChatScreenHeader />
{#if !isEmpty}
{#if isServerLoading}
<ServerLoadingSplash />
{:else}
<div
bind:this={chatScrollContainer}
aria-label="Chat interface with file drop zone"
@@ -356,26 +363,42 @@
onscroll={handleScroll}
role="main"
>
<div class="flex flex-col">
<ChatMessages
class="mb-16 md:mb-24"
messages={activeMessages()}
onUserAction={() => {
autoScroll.enable();
autoScroll.scrollToBottom();
}}
/>
<div class="flex grow flex-col pt-14">
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onUserAction={() => {
autoScroll.enable();
autoScroll.scrollToBottom();
}}
/>
{/if}
<div
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
in:slide={{ duration: 150, axis: 'y' }}
class="pointer-events-none {isEmpty
? 'absolute bottom-[calc(50dvh-7rem)]'
: 'sticky bottom-4'} right-4 left-4 mt-auto pt-16 transition-all duration-200"
>
<ChatScreenProcessingInfo />
{#if isEmpty}
<div class="mb-8 px-4 text-center" use:fadeInView={{ duration: 300 }}>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
<p class="text-muted-foreground md:text-lg">
{serverStore.props?.modalities?.audio
? 'Record audio, type a message '
: 'Type a message'} or upload files to get started
</p>
</div>
{/if}
{#if page.params.id}
<ChatScreenProcessingInfo />
{/if}
{#if hasPropsError}
<div
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
in:fly={{ y: 10, duration: 250 }}
use:fadeInView={{ y: 10, duration: 250 }}
>
<Alert.Root variant="destructive">
<AlertTriangle class="h-4 w-4" />
@@ -412,69 +435,6 @@
</div>
</div>
</div>
{:else if isServerLoading}
<!-- Server Loading State -->
<ServerLoadingSplash />
{:else}
<div
aria-label="Welcome screen with file drop zone"
class="flex h-full items-center justify-center"
ondragenter={handleDragEnter}
ondragleave={handleDragLeave}
ondragover={handleDragOver}
ondrop={handleDrop}
role="main"
>
<div class="w-full max-w-[48rem] px-4">
<div class="mb-10 text-center" in:fade={{ duration: 300 }}>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">llama.cpp</h1>
<p class="text-muted-foreground md:text-lg">
{serverStore.props?.modalities?.audio
? 'Record audio, type a message '
: 'Type a message'} or upload files to get started
</p>
</div>
{#if hasPropsError}
<div class="mb-4" in:fly={{ y: 10, duration: 250 }}>
<Alert.Root variant="destructive">
<AlertTriangle class="h-4 w-4" />
<Alert.Title class="flex items-center justify-between">
<span>Server unavailable</span>
<button
onclick={() => serverStore.fetch()}
disabled={isServerLoading}
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
>
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
{isServerLoading ? 'Retrying...' : 'Retry'}
</button>
</Alert.Title>
<Alert.Description>{serverError()}</Alert.Description>
</Alert.Root>
</div>
{/if}
<div in:fly={{ y: 10, duration: 250, delay: hasPropsError ? 0 : 300 }}>
<ChatScreenForm
disabled={hasPropsError}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
showHelperText
bind:uploadedFiles
/>
</div>
</div>
</div>
{/if}
<!-- File Upload Error Alert Dialog -->
@@ -575,21 +535,3 @@
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
/>
<style>
.conversation-chat-form {
position: relative;
&::after {
content: '';
position: absolute;
bottom: 0;
z-index: -1;
left: 0;
right: 0;
width: 100%;
height: 2.375rem;
background-color: var(--background);
}
}
</style>

Some files were not shown because too many files have changed in this diff Show More