Compare commits

20 Commits

Author SHA1 Message Date
bssrdf ecbf01d441 add tensor type checking as part of cuda graph properties (#19186) 2026-01-30 12:57:52 +08:00
s8322 1025fd2c09 sycl: implement GGML_UNARY_OP_SOFTPLUS (#19114)
* sycl: add softplus unary op implementation

* sycl: add softplus unary op implementation

* docs(ops): mark SYCL SOFTPLUS as supported

* docs: update SYCL status for SOFTPLUS
2026-01-30 12:01:38 +08:00
RachelMantel c7358ddf64 sycl: implement GGML_OP_TRI (#19089)
* sycl: implement GGML_OP_TRI

* docs: update ops.md for SYCL TRI

* docs: regenerate ops.md

* docs: update SYCL support for GGML_OP_TRI
2026-01-30 12:00:49 +08:00
DDXDB d284baf1b5 Fix typos in SYCL documentation (#19162)
* Fix typos in SYCL documentation

* Update SYCL.md

* Update SYCL.md

* Update SYCL.md

* Update docs/backend/SYCL.md

Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>

* Update SYCL.md

---------

Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2026-01-30 09:46:57 +08:00
Zheyuan Chen bd90fc74c3 ggml-webgpu: improve FlashAttention performance by software pipelining (#19151)
* webgpu : pipeline flash_attn Q/K loads in WGSL

* ggml-webgpu: unroll Q*K accumulation inner loop

* ggml-webgpu: vectorization

* ggml-webgpu: unrolling

* ggml-webgpu: remove redundant unrolling

* ggml-webgpu: restore the config

* ggml-webgpu: remove redundant comments

* ggml-webgpu: formatting

* ggml-webgpu: formatting and remove vectorization

* ggml-webgpu: remove unnecessary constants

* ggml-webgpu: change QKV buffer to read_write to pass validation

* ggml-webgpu: add explanation for the additional bracket around Q K accumulate

* Indentation and for -> if for tail

* Kick off CI on wgsl only commits

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
2026-01-29 14:05:30 -08:00
Todor Boinovski ce38a4db47 hexagon: enable offloading to Hexagon on Windows on Snapdragon (#19150)
* hexagon: updates to enable offloading to HTP on WoS

* Update windows.md

* Update windows.md

* hexagon: enable -O3 optimizations

* hexagon: move all _WINDOWS conditional compilation to _WIN32

* hexagon: updates to enable offloading to HTP on WoS

* hexagon: use run-time vs load-time dynamic linking for cdsp driver interface

* refactor htp-drv

* hexagon: add run-bench.ps1 script

* hexagon: htdrv refactor

* hexagon: unify Android and Windows build readmes

* hexagon: update README.md

* hexagon: refactor htpdrv

* hexagon: drv refactor

* hexagon: more drv refactor

* hexagon: fixes for android builds

* hexagon: factor out dl into ggml-backend-dl

* hexagon: add run-tool.ps1 script

* hexagon: merge htp-utils in htp-drv and remove unused code

* wos: no need for getopt_custom.h

* wos: add missing CR in htpdrv

* hexagon: ndev enforcement applies only to the Android devices

* hexagon: add support for generating and signing .cat file

* hexagon: add .inf file

* hexagon: working auto-signing and improved windows builds

* hexagon: further improve skel build

* hexagon: add rough WoS guide

* hexagon: updated windows guide

* hexagon: improve cmake handling of certs and logging

* hexagon: improve windows setup/build doc

* hexagon: more windows readme updates

* hexagon: windows readme updates

* hexagon: windows readme updates

* hexagon: windows readme updates

* hexagon: windows readme updates

* Update windows.md

* Update windows.md

* snapdragon: rename docs/backend/hexagon to docs/backends/snapdragon

Also added a PowerShell script to simplify build env setup.

* hexagon: remove trailing whitespace and move cmake requirement to user-presets

* hexagon: fix CMakeUserPresets path in workflow yaml

* hexagon: introduce local version of libdl.h

* hexagon: fix src1 reuse logic

gpt-oss needs a bigger lookahead window.
The check for src[1] itself being quantized was wrong.

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
2026-01-29 12:33:21 -08:00
Georgi Gerganov 4fdbc1e4db cuda : fix nkvo, offload and cuda graph node properties matching (#19165)
* cuda : fix nkvo

* cont : more robust cuda graph node property matching

* cont : restore pre-leafs implementation

* cont : comments + static_assert
2026-01-29 18:45:30 +02:00
Aldehir Rojas 7b7ae857f6 chat : add parsing for solar-open-100b (#18540)
* chat : add parsing for solar-open-100b

* add comments to rules

* cont : make assistant start optional

* cont : remove assistant start prefix altogether

---------

Co-authored-by: Piotr Wilkin (ilintar) <piotr.wilkin@syndatis.com>
2026-01-29 16:06:15 +01:00
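The diff for this commit swaps the last `<|flush|>` token in the rendered prompt for `<|end|>` when running inference without a generation prompt. A minimal Python sketch of that "replace only the last occurrence" step follows; the helper name `replace_last` is hypothetical and is not part of llama.cpp.

```python
def replace_last(text, old, new):
    """Replace only the final occurrence of `old` in `text`; return `text` unchanged if absent."""
    pos = text.rfind(old)  # index of the last occurrence, or -1 when not found
    if pos == -1:
        return text
    return text[:pos] + new + text[pos + len(old):]


prompt = "<|begin|>assistant<|content|>hi<|flush|><|begin|>assistant<|content|>bye<|flush|>"
fixed = replace_last(prompt, "<|flush|>", "<|end|>")
```

Earlier `<|flush|>` tokens are left untouched, which mirrors the `rfind` plus `replace` pair in the C++ change.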
Andrew Marshall 84b0a98319 webui: Update Svelte to fix effect_update_depth_exceeded errors (#19144)
The upstream fix is first available in 5.38.2, so constrain to at least
that version.

Rebuild pre-compiled webui index.html.gz based on these changes.

See also:
https://github.com/ggml-org/llama.cpp/issues/16347
https://github.com/huntabyte/bits-ui/issues/1687
https://github.com/sveltejs/svelte/issues/16548
2026-01-29 15:56:39 +01:00
Sigbjørn Skjæret b45ef2702c jinja : do not pass empty tools and add some none filters (#19176) 2026-01-29 14:06:54 +01:00
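The "do not pass empty tools" part of this commit can be sketched as follows. This is an illustrative Python rendering of the condition that appears in the diff for `apply()` (set `tools` only when an override exists or the tool list is non-empty); the function name `build_template_context` and its parameters are assumptions made for the example.

```python
def build_template_context(messages, tools, bos_token, eos_token, tools_override=None):
    """Build the dict handed to a chat template, leaving out 'tools' when there are none."""
    context = {
        "messages": messages,
        "bos_token": bos_token,
        "eos_token": eos_token,
    }
    # Add the key only when an override was supplied or the tool list is non-empty,
    # so the template never sees a defined-but-empty 'tools' value by default.
    if tools_override is not None or tools:
        context["tools"] = tools_override if tools_override is not None else tools
    return context


without_tools = build_template_context([], [], "<s>", "</s>")
with_tools = build_template_context([], [{"name": "f"}], "<s>", "</s>")
```

A template that branches on whether `tools` is defined would then take its tool-free path when no tools were given, which is presumably the intent of the change.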
yulo f3dd7b8e68 HIP: add mmf for CDNA (#18896)
* refactor mmf rows_per_block

* speed up compile

* pass cdna compile

* fix cuda error

* clean up mmf

* f32 mmf

* clean float mma

* fix mmf error

* faster mmf

* extend tile k

* fix compile error

* Revert "extend tile k"

This reverts commit 4d2ef3d483.

* fix smem overflow

* speed up compiling mmf

* speed up compile for hip

* 512 block for cdna

* config pad size

* fix as comment

* update select logic

* move some code to cuh

* fix as comment

* correct cdna3 config

---------

Co-authored-by: zhang hui <you@example.com>
2026-01-29 11:10:53 +01:00
Georgi Gerganov eed25bc6b0 arg : add -kvu to llama-batched-bench (#19172) 2026-01-29 08:50:47 +02:00
Vishal Singh b33df266d0 ggml-zendnn : resolve ZenDNN backend cross-module symbol dependency (#19159) 2026-01-29 12:28:57 +08:00
Aman Gupta 3bcc990997 CUDA: refactor topk-moe to enable more models (GLM 4.7, Nemotron etc.) (#19126) 2026-01-29 10:31:28 +08:00
Neo Zhang d4964a7c66 sycl: fix norm kernels: l2_norm, group_norm, rms_norm by removing assert to support more cases (#19154)
Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2026-01-29 09:20:22 +08:00
Sigbjørn Skjæret 50e8962f79 ci : find latest release with asset for winget (#19161) 2026-01-28 22:05:39 +01:00
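The lookup this commit introduces (take the first release that actually carries a `win-vulkan` asset, rather than blindly using the newest release) can be sketched in Python. The field names `tag_name`, `assets`, `name` and `browser_download_url` are taken from the workflow diff; the function name, the sample tags and the example URL are hypothetical, and the list is assumed to be ordered newest first.

```python
def find_release_with_asset(releases, needle="win-vulkan"):
    """Return (tag_name, download_url) for the first release with a matching asset, else None."""
    for release in releases:
        for asset in release.get("assets", []):
            if needle in asset["name"]:
                return release["tag_name"], asset["browser_download_url"]
    return None


releases = [
    {"tag_name": "b0003", "assets": []},  # newest release, assets not uploaded yet
    {"tag_name": "b0002", "assets": [
        {"name": "llama-b0002-bin-win-vulkan-x64.zip",
         "browser_download_url": "https://example.invalid/b0002-win-vulkan.zip"},
    ]},
]
result = find_release_with_asset(releases)
```

Unlike the destructuring in the workflow script, this sketch returns `None` instead of failing when no release has a matching asset.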
Ruben Ortlam f6b533d898 Vulkan Flash Attention Coopmat1 Refactor (#19075)
* vulkan: use coopmat for flash attention p*v matrix multiplication

* fix P loading issue

* fix barrier position

* remove reduction that is no longer needed

* move max thread reduction into loop

* remove osh padding

* add bounds checks and padding

* remove unused code

* fix shmem sizes, loop duration and accesses

* don't overwrite Qf, add new shared psh buffer instead

* add missing bounds checks

* use subgroup reductions

* optimize

* move bounds check, reduce barriers

* support other Bc values and other subgroup sizes

* remove D_split

* replace Of register array with shared memory Ofsh array

* parallelize HSV across the rowgroups

* go back to Of in registers, not shmem

* vectorize sfsh

* don't store entire K tile in shmem

* fixes

* load large k tiles to shmem on Nvidia

* adapt shared memory host check function to shader changes

* remove Bc 32 case

* remove unused variable

* fix missing mask reduction tmspsh barrier

* fix mask bounds check

* fix rowmax f16 under/overflow to inf

* fix flash_attn_cm2 BLOCK_SIZE preprocessor directives
2026-01-28 18:52:45 +01:00
Sascha Rogmann 72d3b1898a spec : add self‑speculative decoding (no draft model required) + refactor (#18471)
* server: introduce self-speculative decoding

* server: moved self-call into speculative.cpp

* can_speculate() includes self-speculation

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: can_speculate() tests self-spec

* server: replace can_speculate() with slot.can_speculate()

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* common: use %zu format specifier for size_t in logging

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* server: can_speculate() requires a task instance

* common: ngram map, config self-speculative decoding

* common: add enum common_speculative_type

* common: add vector of speculative states

* common: add option --spec-draftless

* server: cleanup (remove slot.batch_spec, rename)

* common: moved self-spec impl to ngram-map

* common: cleanup (use common_speculative_state_draft)

* spec : refactor

* cont : naming

* spec: remove --spec-config

* doc: (draftless) speculative decoding

* common: print performance in spec decoding

* minor : cleanup

* common : better names

* minor : cleanup + fix build

* minor: comments

* CODEOWNERS: add common/ngram-map.* (#18471)

* common : rename speculative.draftless_type -> speculative.type

* ngram-map : fix uninitialized values

* ngram-map : take into account the input can become shorter

* ngram-map : revert len check for now

* arg : change `--spec-draftless` -> `--spec-type`

* spec : add common_speculative_state::accept()

* spec : refactor + add common_speculative_begin()

* spec : fix begin() call with mtmd

* spec : additional refactor + remove common_speculative_params

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-28 19:42:42 +02:00
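The idea behind draftless (self-)speculative decoding can be illustrated with a small Python sketch: look for an earlier occurrence of the last n tokens in the token history and propose the tokens that followed it as the draft, which the main model then verifies. This is not the code in common/ngram-map.*; the function name `draft_from_history` is hypothetical, and `n` and `m` only loosely correspond to the lookup n-gram and draft m-gram lengths described for `--spec-ngram-size-n` and `--spec-ngram-size-m`.

```python
def draft_from_history(tokens, n, m):
    """Propose up to m draft tokens by matching the last n tokens earlier in the history."""
    if n < 1 or m < 1 or len(tokens) <= n:
        return []
    key = tokens[-n:]
    # Scan backwards so the most recent earlier occurrence wins.
    # Starting at len(tokens) - n - 1 skips the trailing n-gram itself.
    for start in range(len(tokens) - n - 1, -1, -1):
        if tokens[start:start + n] == key:
            return tokens[start + n:start + n + m]
    return []


history = [1, 2, 3, 4, 5, 1, 2, 3]
draft = draft_from_history(history, n=3, m=2)
```

Because no second model is involved, a miss costs only the lookup, while a hit on repetitive text (code, quoted passages) lets several tokens be checked in one batch.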
Daniel Bevenius ebf5725870 convert : yield Mamba2Model/GraniteMoeModel modify_tensors (#19157)
* convert : yield Mamba2Model/GraniteMoeModel modify_tensors

This commit updates the `GraniteHybridModel` class' modify_tensors
function to properly delegate to `Mamba2Model.modify_tensors` and
`GraniteMoeModel.modify_tensors` using 'yield from' instead of 'return'.

The motivation for this is that modify_tensors is a generator function
(it uses 'yield from'), but the two calls above use 'return' statements
and do not yield anything, so the caller of this function never receives
the values they produce. This causes layer tensors to be silently dropped
during conversion.
2026-01-28 16:49:36 +01:00
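The generator pitfall described in the commit message above can be reproduced in isolation. The sketch below uses made-up function and tensor names (`base_tensors`, `modify_return`, `modify_yield_from`); it only demonstrates the Python behaviour and is not the converter code itself.

```python
def base_tensors(name):
    """Stand-in for a parent class' modify_tensors generator."""
    yield (name + ".weight", 1)
    yield (name + ".bias", 2)


def modify_return(name, special):
    if special:
        # Bug: inside a generator, 'return <value>' just ends iteration;
        # the inner generator is never iterated, so nothing is yielded.
        return base_tensors(name)
    yield (name, 0)


def modify_yield_from(name, special):
    if special:
        # Fix: delegate to the inner generator so its items reach the caller.
        yield from base_tensors(name)
        return
    yield (name, 0)


dropped = list(modify_return("blk.0", True))
kept = list(modify_yield_from("blk.0", True))
```

Here `dropped` is an empty list while `kept` holds both tensors, and no error is raised in either case, which is why the loss went unnoticed.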
Patryk Kaminski 0cd7032ca4 ggml-sycl: remove unused syclcompat header (#19140)
The syclcompat/math.hpp header is not used anymore. The change that introduced it was successfully reverted (https://github.com/ggml-org/llama.cpp/pull/17826).
This include path will become obsolete and be dropped in oneAPI 2026.0, effectively breaking ggml-sycl builds.
2026-01-28 23:33:54 +08:00
74 changed files with 5577 additions and 2553 deletions
+5 -3
@@ -21,7 +21,8 @@ on:
'**/*.m',
'**/*.metal',
'**/*.comp',
'**/*.glsl'
'**/*.glsl',
'**/*.wgsl'
]
pull_request:
@@ -42,7 +43,8 @@ on:
'**/*.m',
'**/*.metal',
'**/*.comp',
'**/*.glsl'
'**/*.glsl',
'**/*.wgsl'
]
concurrency:
@@ -1371,7 +1373,7 @@ jobs:
id: update_presets
if: ${{ matrix.build == 'arm64-snapdragon' }}
run: |
cp docs/backend/hexagon/CMakeUserPresets.json .
cp docs/backend/snapdragon/CMakeUserPresets.json .
- name: Build
id: ndk_build
+7 -6
@@ -28,16 +28,17 @@ jobs:
owner: context.repo.owner,
repo: context.repo.repo,
});
console.log("Latest release:", releases[0].tag_name);
return releases[0].tag_name;
const { tag_name: version, assets: assets } = releases.find(({assets}) => assets.find(asset => asset.name.includes('win-vulkan')));
const { browser_download_url: asset_url } = assets.find(asset => asset.name.includes('win-vulkan'));
console.log("Latest release:", version);
core.setOutput('VERSION', version);
core.setOutput('ASSETURL', asset_url);
- name: Update manifest
env:
VERSION: ${{ steps.find_latest_release.outputs.result }}
run: |
echo "Updating manifest..."
komac update --version ${{ env.VERSION }} \
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
komac update --version ${{ steps.find_latest_release.outputs.VERSION }} \
--urls "${{ steps.find_latest_release.outputs.ASSETURL }}" \
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
--submit \
ggml.llamacpp
+1
@@ -18,6 +18,7 @@
/common/jinja/ @ngxson @CISC @aldehir
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/ngram-map.* @srogmann
/common/peg-parser.* @aldehir
/common/sampling.* @ggerganov
/common/speculative.* @ggerganov
+2
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
log.h
ngram-cache.cpp
ngram-cache.h
ngram-map.cpp
ngram-map.h
peg-parser.cpp
peg-parser.h
preset.cpp
+75 -14
@@ -6,6 +6,7 @@
#include "json-schema-to-grammar.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "preset.h"
// fix problem with std::min and std::max
@@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.mmproj = res.mmproj;
}
// only download mmproj if the current example is using it
for (auto & ex : mmproj_examples) {
for (const auto & ex : mmproj_examples) {
if (ctx_arg.ex == ex) {
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
break;
}
}
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
}
// model is required (except for server)
@@ -1216,16 +1217,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-lcs", "--lookup-cache-static"}, "FNAME",
"path to static lookup cache to use for lookup decoding (not updated by generation)",
[](common_params & params, const std::string & value) {
params.lookup_cache_static = value;
params.speculative.lookup_cache_static = value;
}
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
[](common_params & params, const std::string & value) {
params.lookup_cache_dynamic = value;
params.speculative.lookup_cache_dynamic = value;
}
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-c", "--ctx-size"}, "N",
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
@@ -1300,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.kv_unified = value;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
@@ -2563,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
"Same as --hf-repo, but for the draft model (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model.hf_repo = value;
params.speculative.mparams_dft.hf_repo = value;
}
).set_env("LLAMA_ARG_HFD_REPO"));
add_opt(common_arg(
@@ -3384,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model.path = value;
params.speculative.mparams_dft.path = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
@@ -3394,6 +3395,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.replacements.push_back({ tgt, dft });
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
} else if (value == "ngram-map-k") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
} else if (value == "ngram-map-k4v") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-size-n"}, "N",
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
[](common_params & params, int value) {
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_n = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-size-m"}, "N",
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
[](common_params & params, int value) {
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_m = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-check-rate"}, "N",
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("ngram check rate must be at least 1");
}
params.speculative.ngram_check_rate = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-min-hits"}, "N",
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("ngram min hits must be at least 1");
}
params.speculative.ngram_min_hits = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
@@ -3620,8 +3681,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.port = 8012;
params.n_ubatch = 1024;
params.n_batch = 1024;
@@ -3636,8 +3697,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.port = 8012;
params.n_ubatch = 1024;
params.n_batch = 1024;
+161 -11
@@ -771,10 +771,12 @@ static std::string apply(
nlohmann::ordered_json inp = nlohmann::ordered_json{
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
{"bos_token", tmpl.bos_token()},
{"eos_token", tmpl.eos_token()},
};
if (tools_override.has_value() || !inputs.tools.empty()) {
inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools;
}
if (inputs.extra_context.is_object()) {
// TODO: do we need to merge, or replacing is fine?
for (const auto & [k, v] : inputs.extra_context.items()) {
@@ -790,9 +792,6 @@ static std::string apply(
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
if (inp["tools"].is_null()) {
inp["tools"] = json::array();
}
jinja::global_from_json(ctx, inp, inputs.mark_input);
@@ -2219,12 +2218,11 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
const std::optional<json> tools_override = json();
const std::optional<json> additional_context = json {
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
};
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context);
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -2573,20 +2571,165 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// TODO: Reasoning effort
json additional_context = {};
// Copy `reasoning_content` to `reasoning`
auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
auto adjusted_message = msg;
adjusted_message["reasoning"] = msg.at("reasoning_content");
adjusted_message.erase("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
}
}
data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto include_grammar = true;
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
// Check if we need to replace the flush token with end token during inference and without generation prompt.
if (inputs.is_inference && !inputs.add_generation_prompt) {
static constexpr std::string_view return_token = "<|flush|>";
static constexpr std::string_view end_token = "<|end|>";
if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
prompt.replace(pos, return_token.length(), end_token);
}
}
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
"<|think|>",
"<|content|>",
"<|begin|>",
"<|end|>",
"<|tool_calls|>",
"<|tool_call:begin|>",
"<|tool_call:end|>",
"<|tool_call:name|>",
"<|tool_call:args|>",
};
// TODO: Tool calling
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
auto lit_think = p.atomic(p.literal("<|think|>"));
auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant"));
auto lit_content = p.atomic(p.literal("<|content|>"));
auto lit_end = p.atomic(p.literal("<|end|>"));
auto parser_until_end = p.until("<|end|>");
// reasoning <- "<|think|>" (!"<|end|>" .)*
auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end));
// content <- "<|content|>" (!"<|end|>" .)*
auto parser_content = p.rule("content", lit_content + p.content(parser_until_end));
// wrap_choice(items) <- item-choice wrapped*
// item-choice <- items[0] / ... / items[n]
// wrapped <- "<|end|><|begin|>assistant" item-choice
auto wrap_choice = [&](const std::vector<common_peg_parser> & items) {
auto choice = p.choice(items);
return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice);
};
// wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ...
auto wrap_seq = [&](const std::vector<common_peg_parser> & items) {
auto seq = p.sequence();
for (auto i = 0u; i < items.size(); i++) {
if (i == 0) {
seq += items[i];
continue;
}
seq += lit_end + lit_assistant_begin + items[i];
}
return seq;
};
// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
return p.choice({
wrap_seq({parser_reasoning, parser_response_format}),
wrap_seq({parser_response_format})
});
}
auto lit_tool_call_begin = p.literal("<|tool_call:begin|>");
auto lit_tool_call_name = p.literal("<|tool_call:name|>");
auto lit_tool_call_args = p.literal("<|tool_call:args|>");
auto lit_tool_call_end = p.literal("<|tool_call:end|>");
// Tool call parser
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
auto parser_tool_call = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
const auto & schema = function.at("parameters");
// tool(name, schema) <- name "<|tool_call:args|>" schema
parser_tool_call |= p.rule("tool-" + name,
p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args)
+ p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
});
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
// tool-calls <- "<|tool_calls|>" tool-call+
// tool-call <- "<|tool_call:begin|>" call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>"
// call-id <- [a-zA-Z0-9_-]+
// tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema)
auto parser_tool_calls = p.trigger_rule("tool-calls",
p.atomic(p.literal("<|tool_calls|>"))
+ p.repeat(
p.tool_open(
lit_tool_call_begin
+ p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1))
+ lit_tool_call_name
+ p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args))
+ parser_tool_call
+ p.tool_close(lit_tool_call_end),
/* min = */ 1,
/* max = */ max_calls));
if (min_calls == 1) {
// If required, then try any combination of the reasoning, content, and tool call
return p.choice({
wrap_seq({parser_reasoning, parser_content, parser_tool_calls}),
wrap_seq({parser_reasoning, parser_tool_calls}),
wrap_seq({parser_content, parser_tool_calls}),
wrap_seq({parser_tool_calls})
});
}
return wrap_choice({parser_reasoning, parser_content, parser_tool_calls});
}
// Content only parser
include_grammar = false;
return wrap_choice({parser_reasoning, parser_content});
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
});
data.grammar_triggers = {
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"}
};
}
return data;
}
@@ -3043,6 +3186,13 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_apriel_1_5(tmpl, params);
}
// Solar Open
if (src.find("<|tool_response:begin|>") != std::string::npos &&
src.find("<|tool_response:name|>") != std::string::npos &&
src.find("<|tool_response:result|>") != std::string::npos) {
return common_chat_params_init_solar_open(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
+4 -5
@@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
params.tensor_split,
params.tensor_buft_overrides.data(),
params.fit_params_target.data(),
params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
@@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
void common_init_result::free_context() {
pimpl->context.reset();
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));
+46 -18
@@ -164,6 +164,16 @@ enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
};
// sampling parameters
struct common_params_sampling {
@@ -243,16 +253,35 @@ struct common_params_model {
};
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
// general-purpose speculative decoding parameters
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
// ngram-based speculative decoding
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
// draft-model speculative decoding
struct common_params_model mparams_dft;
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@@ -260,7 +289,14 @@ struct common_params_speculative {
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct common_params_model model;
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool has_dft() const {
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
}
};
struct common_params_vocoder {
@@ -378,8 +414,6 @@ struct common_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
// llama-debug specific options
@@ -575,10 +609,6 @@ struct common_params {
// return false from callback to abort model loading or true to continue
llama_progress_callback load_progress_callback = NULL;
void * load_progress_callback_user_data = NULL;
bool has_speculative() const {
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
}
};
// call once at the start of a program if it uses libcommon
@@ -714,8 +744,6 @@ struct common_init_result {
std::vector<llama_adapter_lora_ptr> & lora();
void free_context();
private:
struct impl;
std::unique_ptr<impl> pimpl;
+10
@@ -1028,6 +1028,16 @@ const func_builtins & value_none_t::get_builtins() const {
{"safe", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"strip", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"items", empty_value_fn<value_array>},
{"map", empty_value_fn<value_array>},
{"reject", empty_value_fn<value_array>},
{"rejectattr", empty_value_fn<value_array>},
{"select", empty_value_fn<value_array>},
{"selectattr", empty_value_fn<value_array>},
{"unique", empty_value_fn<value_array>},
};
return builtins;
}
+3 -4
@@ -192,12 +192,12 @@ void common_ngram_cache_draft(
break;
}
LOG(" - draft candidate: token=%d\n", drafted_token);
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
draft.push_back(drafted_token);
}
}
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
std::ofstream file_out(filename, std::ios::binary);
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
const common_ngram ngram = item.first;
@@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
}
}
}
common_ngram_cache common_ngram_cache_load(std::string & filename) {
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
std::ifstream hashmap_file(filename, std::ios::binary);
if (!hashmap_file) {
throw std::ifstream::failure("Unable to open file " + filename);
+2 -2
@@ -88,12 +88,12 @@ void common_ngram_cache_draft(
// Save an ngram cache to a file.
// ngram_cache: the ngram cache to save.
// filename: the path under which to save the ngram cache.
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
// Load an ngram cache saved with common_ngram_cache_save.
// filename: the path from which to load the ngram cache.
// returns: an ngram cache containing the information saved to filename.
common_ngram_cache common_ngram_cache_load(std::string & filename);
common_ngram_cache common_ngram_cache_load(const std::string & filename);
// Merge two ngram caches.
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
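The move from `std::string &` to `const std::string &` in the two signatures above matters for callers: a non-const lvalue reference cannot bind to a temporary or a string literal, while a const reference can. A minimal sketch follows, using two hypothetical stand-in functions (`path_len_mutable` and `path_len_const` are illustrative names, not llama.cpp functions) that only report the length of the path they receive.

```cpp
#include <cstddef>
#include <string>

// Hypothetical stand-in for the old signature: needs a named, mutable std::string.
std::size_t path_len_mutable(std::string & filename) {
    return filename.size();
}

// Hypothetical stand-in for the new signature: also accepts literals and temporaries.
std::size_t path_len_const(const std::string & filename) {
    return filename.size();
}

// path_len_const("cache.bin");                  // OK: a temporary std::string binds to const &
// path_len_const(std::string("dir/") + "c.bin"); // OK: result of operator+ is a temporary
// path_len_mutable("cache.bin");                // does not compile: cannot bind a temporary to std::string &
```

With the const signature, a caller no longer has to copy a path into a named local variable just to pass it along.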
+367
@@ -0,0 +1,367 @@
#include "common.h"
#include "log.h"
#include "ngram-map.h"
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <sstream>
// n-gram simple
//
/**
* Perform speculative generation using the model's own token history.
* Searches for a matching pattern in the token history and returns draft tokens.
*
* @param state Current state of this implementation
* @param tokens Token history to search in
* @param sampled Last sampled token
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled) {
// Simple implementation of self-speculative decoding without a draft model.
//
const size_t cur_len = tokens.size();
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (state.idx_last_check + state.config.check_rate > cur_len) {
llama_tokens draft_tokens;
return draft_tokens;
}
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
// vector for tokens we want to verify.
// return empty vector if there is no match.
llama_tokens draft_tokens;
// We need more than n_draft_min + n_draft_max + 1 tokens.
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
return draft_tokens;
}
// pattern search
llama_tokens pattern;
pattern.reserve(n_draft_min);
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
pattern.push_back(tokens[j]);
}
pattern.push_back(sampled); // add the last token to the pattern
// We do a search in the token history.
state.idx_last_check = cur_len;
size_t match_pos = 0; // we ignore position 0, position 0 == no match
// search backwards, but skip the current match (we are currently there)
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < pattern.size(); ++k) {
if (tokens[j + k] != pattern[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos == 0) {
return draft_tokens;
}
const size_t copy_max = std::min(
n_draft_max,
cur_len - (match_pos + n_draft_min)
);
if (copy_max < n_draft_min) {
return draft_tokens;
}
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
__func__, cur_len,
match_pos, pattern.size(), copy_max);
draft_tokens.reserve(copy_max);
for (size_t j = 0; j < copy_max; ++j) {
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
}
return draft_tokens;
}
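The lookup in `common_ngram_simple_draft` can be sketched on plain integer tokens. The version below is a simplified illustration, not the function from this diff: `token` and `draft_from_history` are hypothetical names, and the sketch leaves out the `check_rate` throttling and the final `copy_max < n_draft_min` check. It keeps the two details that are easy to miss: the pattern is the last `n - 1` history tokens plus the freshly sampled token, and position 0 doubles as the "no match" marker, so a match that starts at index 0 is never used.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using token = int; // stand-in for llama_token (assumption for this sketch)

// Hypothetical helper: returns up to m tokens that followed the most recent
// earlier occurrence of the current n-gram, or an empty vector if none is found.
std::vector<token> draft_from_history(const std::vector<token> & history,
                                      token sampled, std::size_t n, std::size_t m) {
    assert(n >= 1 && m >= 1);
    std::vector<token> draft;
    const std::size_t len = history.size();
    if (len <= n + m + 1) {
        return draft; // history too short
    }

    // pattern = last (n - 1) history tokens followed by the sampled token
    std::vector<token> pattern(history.begin() + (len - n + 1), history.end());
    pattern.push_back(sampled);

    // search backwards; index 0 is reserved to mean "no match"
    std::size_t match_pos = 0;
    for (std::size_t j = len - n - 1; j > 0; --j) {
        if (std::equal(pattern.begin(), pattern.end(), history.begin() + j)) {
            match_pos = j;
            break;
        }
    }
    if (match_pos == 0) {
        return draft;
    }

    // copy the tokens that followed the matched n-gram, at most m of them
    const std::size_t copy_max = std::min(m, len - (match_pos + n));
    draft.assign(history.begin() + (match_pos + n),
                 history.begin() + (match_pos + n + copy_max));
    return draft;
}
```

For example, with history `{9, 1, 2, 3, 4, 5, 6, 7, 1, 2}`, sampled token `3`, `n = 3` and `m = 2`, the pattern `{1, 2, 3}` is found at index 1 and the draft is `{4, 5}`.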
// n-gram map
//
// maximum number of counted values of a ngram map value.
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
void common_ngram_map_draft(common_ngram_map & map,
const llama_tokens & inp, llama_token sampled,
llama_tokens & draft) {
// reset last key and value.
map.last_draft_created = false;
map.last_draft_key_idx = 0;
map.last_draft_value_idx = 0;
const size_t cur_len = inp.size();
const uint16_t n = map.size_key;
const uint16_t m = map.size_value;
if (cur_len < static_cast<size_t>(2 * n + m)) {
return;
}
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (map.idx_last_check + map.check_rate > cur_len) {
return;
}
map.idx_last_check = cur_len;
// search pattern, the key n-gram
std::vector<llama_token> key_tokens;
key_tokens.reserve(n);
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
key_tokens.push_back(inp[j]);
}
key_tokens.push_back(sampled);
// search for the key in the map
size_t match_pos = 0;
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < n; ++k) {
if (inp[j + k] != key_tokens[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos > 0) {
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
cur_len, n, m, key_tokens.size(), sampled, match_pos);
}
if (match_pos == 0) {
return;
}
// We have a match, now we look for the statistics of the key.
size_t key_offset = map.keys.size(); // offset in the map
// We iterate through the std::vector<common_ngram_map_key> map->keys.
for (size_t i = 0; i < map.keys.size(); ++i) {
bool match = true;
for (size_t j = 0; j < n; ++j) {
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
match = false;
break;
}
}
if (match) {
key_offset = i;
break;
}
}
if (key_offset == map.keys.size()) {
// We create a new key-entry, it will get offset key_offset.
common_ngram_map_key new_key;
new_key.key_idx = match_pos;
new_key.stat_idx = 0;
new_key.key_num = 0;
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
new_key.values[i].value_num = 0;
new_key.values[i].n_accepted = m;
}
map.keys.push_back(new_key);
}
// our key n-gram:
common_ngram_map_key & curr_key = map.keys[key_offset];
// update number of key hits
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
if (map.key_only) {
// simple mode:
// Fill in the draft with the m tokens following the key.
// We work with value values[0] only.
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
for (int i = 0; i < n_draft_tokens; ++i) {
draft.push_back(inp[match_pos + n + i]);
}
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
key_offset, curr_key.key_num, draft.size());
map.last_draft_created = false;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
}
if (curr_key.key_num < map.min_hits) {
// not enough hits to consider this a good draft
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
key_offset, curr_key.key_num, map.min_hits);
return;
}
// complex mode: examine the different m-grams after this key n-gram.
//
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
// does the key n-gram begin at index i?
bool match_key = true;
for (size_t k = 0; k < n; ++k) {
if (inp[i + k] != key_tokens[k]) {
match_key = false;
break;
}
}
if (!match_key) {
continue;
}
// Do we have an existing value m-gram or a new one after the key at index i?
size_t idx_begin_value_key = i + n;
int idx_value = -1;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
size_t idx_begin_value_v = curr_key.values[v].value_idx;
if (idx_begin_value_v == 0) {
// We found an empty value slot => we found a new value m-gram after the key n-gram.
curr_key.values[v].value_idx = idx_begin_value_key;
curr_key.values[v].value_num = 0;
curr_key.values[v].n_accepted = m;
idx_value = v;
break;
}
bool match = true;
for (size_t j = 0; j < m; ++j) {
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
match = false;
break;
}
}
if (match) {
// We found an existing value m-gram after the key n-gram.
idx_value = v;
break;
}
}
if (idx_value >= 0) {
// We found a value m-gram of the key n-gram.
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
}
}
// the statistics are updated up to match_pos.
curr_key.stat_idx = match_pos;
// Do we have a value we could use for the draft?
uint16_t max_occur = 0;
int slot_max = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
uint16_t curr_occur = curr_key.values[v].value_num;
if (curr_occur > max_occur) {
max_occur = curr_occur;
slot_max = v;
}
}
// What is the sum of the other occurrences?
uint32_t sum_occur = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (v == slot_max) {
continue;
}
uint16_t curr_occur = curr_key.values[v].value_num;
sum_occur += curr_occur;
}
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
key_offset,
max_occur, sum_occur, slot_max,
curr_key.values[0].value_idx, curr_key.values[0].value_num,
curr_key.values[1].value_idx, curr_key.values[1].value_num,
curr_key.values[2].value_idx, curr_key.values[2].value_num,
curr_key.values[3].value_idx, curr_key.values[3].value_num
);
// Print the tokens of the four values (if idx != 0), use LOG_INF
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (curr_key.values[v].value_idx != 0) {
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
}
}
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
// The most frequent value is not much more frequent than the other values.
// We do not use the draft.
return;
}
// We use the most frequent value values[slot_max] for the draft.
// Fill in the draft with the m tokens following the key.
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
for (int i = 0; i < n_draft_tokens; ++i) {
draft.push_back(inp[match_pos + n + i]);
}
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
key_offset, slot_max,
curr_key.key_num, draft.size());
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = slot_max; // value used for draft generation.
}
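The decision rule at the end of `common_ngram_map_draft` can be isolated into a small sketch: the most frequent value m-gram is used only if it occurred at least three times as often as all other values combined, or if no other value was seen at all. `pick_dominant_slot` below is a hypothetical helper written for illustration. It takes the per-slot occurrence counts and returns the chosen slot, or `-1` when the function above would return without a draft.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: index of the dominant slot, or -1 if no slot dominates.
int pick_dominant_slot(const std::vector<uint16_t> & counts) {
    assert(!counts.empty());

    // find the slot with the highest occurrence count (first one wins on ties)
    uint16_t max_occur = 0;
    std::size_t slot_max = 0;
    for (std::size_t v = 0; v < counts.size(); ++v) {
        if (counts[v] > max_occur) {
            max_occur = counts[v];
            slot_max = v;
        }
    }

    // sum the occurrences of all other slots
    uint32_t sum_occur = 0;
    for (std::size_t v = 0; v < counts.size(); ++v) {
        if (v != slot_max) {
            sum_occur += counts[v];
        }
    }

    // reject when the best slot is less than 3x as frequent as the rest combined
    if (sum_occur > 0 && max_occur < 3 * sum_occur) {
        return -1;
    }
    return static_cast<int>(slot_max);
}
```

With counts `{6, 1, 1, 0}` slot 0 is accepted (6 is not below 3 * 2), while `{5, 1, 1, 0}` is rejected.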
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
if (!map.last_draft_created) {
return;
}
// find the key and its chosen value.
const size_t key_idx = map.last_draft_key_idx;
const size_t val_idx = map.last_draft_value_idx;
// find key corresponding to key_idx.
common_ngram_map_key & curr_key = map.keys[key_idx];
// find value corresponding to val_idx.
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
// update the value statistics
LOG_INF("%s: n_accepted = %d, prev n_accepted = %d\n", __func__,
n_accepted, curr_value.n_accepted);
curr_value.n_accepted = n_accepted;
}
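The accept call feeds back into the next draft: `common_ngram_map_draft` drafts `min(m, n_accepted)` tokens for a value, and a new value starts with `n_accepted = m`. A minimal sketch of that feedback loop, using a hypothetical `value_stats` struct and helper names that are not part of the diff:

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for the per-value statistics; only the field used here.
struct value_stats {
    int16_t n_accepted;
};

// Length of the next draft: at most m tokens, capped by the last accepted count.
int next_draft_len(uint16_t m, const value_stats & v) {
    return std::min(static_cast<int>(m), static_cast<int>(v.n_accepted));
}

// Record how many drafted tokens the target model accepted.
void record_accept(value_stats & v, uint16_t n_accepted) {
    v.n_accepted = static_cast<int16_t>(n_accepted);
}
```

Starting from `n_accepted = 48` with `m = 48`, the first draft has 48 tokens; if only 5 are accepted, the next draft for the same value is limited to 5 tokens.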
// Helper functions.
//
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
std::ostringstream oss;
oss << '[';
for (size_t i = 0; i < length; ++i) {
if (i > 0) {
oss << ", ";
}
oss << inp[start + i];
}
oss << ']';
return oss.str();
}
+105
@@ -0,0 +1,105 @@
#pragma once
//
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
//
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
//
// There are two algorithms implemented:
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
//
#include "llama.h"
#include <vector>
// n-gram simple
//
// config of n-gram simple.
struct common_ngram_simple_config {
uint16_t size_ngram; // size of n-grams to lookup in self-mode
uint16_t size_mgram; // size of m-grams to draft in self-mode
uint16_t check_rate; // run the lookup (speculative decoding without draft model) only every check_rate tokens
};
// current state (and config) of n-gram simple.
struct common_ngram_simple_state {
common_ngram_simple_config config;
size_t idx_last_check = 0; // index of last check in context history (mutable)
common_ngram_simple_state(const common_ngram_simple_config & config)
: config(config) {}
};
// Searches for an n-gram in the history and checks whether a draft sequence should be generated.
// state: the ngram simple state to search in.
// tokens: the tokens generated so far.
// sampled: the token that was just sampled.
// returns: the draft tokens, empty if no matching pattern is found.
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled);
// n-gram map
//
// maximum number of m-gram values stored for each key n-gram.
#define COMMON_NGRAM_MAX_VALUES 4
// statistics of an m-gram after a known n-gram
struct common_ngram_map_value {
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
};
// statistics of an n-gram
struct common_ngram_map_key {
size_t key_idx; // index of key n-gram in token-history
size_t stat_idx; // index of last token of statistics computation (key_num, values)
uint16_t key_num; // number of occurrences of this key n-gram in token-history
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
};
// map from n-grams to following m-grams in token-history
struct common_ngram_map {
uint16_t size_key; // size of key n-grams
uint16_t size_value; // size of value m-grams
bool key_only; // true if only key n-grams are used, no values.
// first draft: vector only, no map.
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
uint16_t check_rate; // run the lookup (speculative decoding without draft model) only every check_rate tokens
uint16_t min_hits; // minimum number of key hits to consider a draft
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
uint16_t check_rate, uint16_t min_hits)
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
check_rate(check_rate), min_hits(min_hits) {}
bool last_draft_created = false; // true if a draft was created at last call.
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
size_t idx_last_check = 0; // index of last check in context history
};
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
// map: the ngram map to search in.
// inp: the tokens generated so far.
// sampled: the token that was just sampled.
// draft: vector to store the draft tokens, initially empty.
void common_ngram_map_draft(
common_ngram_map & map,
const llama_tokens & inp, llama_token sampled,
llama_tokens & draft);
// Update the statistics of a value after a draft was processed.
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
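Both `common_ngram_simple_state` and `common_ngram_map` carry a `check_rate` and an `idx_last_check`. The throttling they implement can be sketched on its own: a lookup runs only if at least `check_rate` tokens were added since the last lookup. `lookup_throttle` below is a hypothetical type for illustration. In the diff, the same comparison is inlined in the draft functions and `idx_last_check` is updated only after the history-length guards pass.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper that isolates the check_rate logic.
struct lookup_throttle {
    uint16_t    check_rate;
    std::size_t idx_last_check = 0; // history length at the last lookup

    explicit lookup_throttle(uint16_t rate) : check_rate(rate) {
        assert(rate > 0);
    }

    // Returns true if a lookup should run for a history of cur_len tokens,
    // and remembers cur_len as the position of the last lookup.
    bool should_check(std::size_t cur_len) {
        if (idx_last_check + check_rate > cur_len) {
            return false; // fewer than check_rate new tokens since the last lookup
        }
        idx_last_check = cur_len;
        return true;
    }
};
```

With `check_rate = 4`, lookups run at history lengths 4 and 8 but are skipped at 3 and 5. With `check_rate = 1`, every new token triggers a lookup.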
+766 -246
File diff suppressed because it is too large.
+23 -21
@@ -5,31 +5,33 @@
struct common_speculative;
struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
// comma separated list of all types
std::string common_speculative_type_name_str();
float p_min = 0.75f; // min probability required to accept a token in the draft
};
// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft
);
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
void common_speculative_free(struct common_speculative * spec);
common_speculative * common_speculative_init(
const common_params_speculative & params,
llama_context * ctx_tgt);
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft);
void common_speculative_free(common_speculative * spec);
void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);
// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
llama_tokens common_speculative_draft(
common_speculative * spec,
const common_params_speculative & params,
const llama_tokens & prompt,
llama_token id_last);
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
+6 -3
@@ -8912,13 +8912,16 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
name.endswith("block_sparse_moe.input_linear.weight")
or "shared_mlp" in name
):
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return
# Determine whether this is a mamba layer or an attention layer
if bid in self._ssm_layers:
return Mamba2Model.modify_tensors(self, data_torch, name, bid)
yield from Mamba2Model.modify_tensors(self, data_torch, name, bid)
return
elif bid in self._attn_layers:
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
def set_gguf_parameters(self):
+7 -6
@@ -35,9 +35,9 @@ The following releases are verified and recommended:
|Commit ID|Tag|Release|Verified Platform| Update date|
|-|-|-|-|-|
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |Arc B580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
## News
@@ -51,7 +51,7 @@ The following releases are verified and recommended:
|-|-|-|-|
|PVC 1550|39|73|+87%|
|Flex 170|39|50|+28%|
|Arc770|42|55|+30%|
|Arc A770|42|55|+30%|
|MTL|13|16|+23%|
|ARL-H|14|17|+21%|
@@ -62,7 +62,7 @@ The following releases are verified and recommended:
- Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
- 2024.5
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc A770.
- Arch Linux is verified successfully.
- 2024.4
@@ -111,7 +111,8 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
|-------------------------------|---------|---------------------------------------|
| Intel Data Center Max Series | Support | Max 1550, 1100 |
| Intel Data Center Flex Series | Support | Flex 170 |
| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 |
| Intel Arc A-Series | Support | Arc A770, Arc A730M, Arc A750 |
| Intel Arc B-Series | Support | Arc B580 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake |
| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 |
@@ -1,5 +1,10 @@
{
"version": 4,
"version": 5,
"cmakeMinimumRequired": {
"major": 3,
"minor": 28,
"patch": 0
},
"configurePresets": [
{
"name": "arm64-android-snapdragon",
@@ -16,7 +21,9 @@
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
"PREBUILT_LIB_DIR": "android_aarch64",
"GGML_OPENMP": "OFF",
"GGML_LLAMAFILE": "OFF",
@@ -31,7 +38,15 @@
"name": "arm64-windows-snapdragon",
"inherits": [ "base", "arm64-windows-llvm" ],
"cacheVariables": {
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
"PREBUILT_LIB_DIR": "windows_aarch64",
"GGML_OPENMP": "OFF",
"GGML_LLAMAFILE": "OFF",
@@ -1,6 +1,8 @@
# Snapdragon-based Android devices
# Snapdragon-based devices
## How to Build
## Setup
### Android
The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain).
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
@@ -12,7 +14,24 @@ This method works on Linux, macOS, and Windows. macOS and Windows users should i
[d]/> cd /workspace
```
The rest of the Android build process assumes that you're running inside the toolchain container.
Note: The rest of the **Android** build process assumes that you're running inside the toolchain container.
### Windows On Snapdragon
Native Windows 11 arm64 builds have the following tool dependencies:
- MS Visual Studio 2026 (Community Edition or Pro)
- MSVC arm64 standard and runtime libraries
- UCRT and Driver Kit
- LLVM core libraries and Clang compiler (winget)
- CMake, Git, Python (winget)
- Hexagon SDK Community Edition 6.4 or later (see windows.md)
- OpenCL SDK 2.3 or later (see windows.md)
Note: The rest of the **Windows** build process assumes that you're running natively in PowerShell.
Adapt the build commands below accordingly.
## How to Build
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
```
@@ -49,24 +68,26 @@ Preset CMake variables:
To generate an installable "package" simply use cmake --install:
```
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon/llama.cpp
-- Install configuration: "Release"
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-cpu.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-opencl.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-hexagon.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v73.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v75.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v79.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v81.so
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml.so
...
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-bench
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-cli
...
```
## How to Install
### Android
For this step, your device needs to be configured for on-device development.
Please see https://developer.android.com/studio/debug/dev-options for details.
@@ -74,10 +95,10 @@ Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device.
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
```
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
~/src/llama.cpp$ adb push pkg-snapdragon/llama.cpp /data/local/tmp/
pkg-snapdragon/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
pkg-snapdragon/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
pkg-snapdragon/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
```
@@ -92,6 +113,11 @@ At this point, you should also install some models:
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
```
### Windows
All artifacts are already installed in the `pkg-snapdragon` folder.
To run, adapt the instructions below to use the PowerShell scripts in `scripts/snapdragon/windows`.
## How to Run
The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables.
+161
@@ -0,0 +1,161 @@
## Overview
This document covers procedures for installing the latest GPU and NPU drivers and the OpenCL and Hexagon SDKs.
To use the Hexagon NPU on Snapdragon Windows devices, the underlying HTP Ops libraries (e.g. libggml-htp-v73.so)
must be included in a .cat file that is digitally signed with a trusted certificate.
This document also covers how to generate personal certificate files (.pfx) and how to configure the system
to allow test signatures (aka test-signing).
## Install the latest Adreno OpenCL SDK
Either use the trimmed down version (optimized for CI) from
https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz
Or download the complete official version from
https://softwarecenter.qualcomm.com/catalog/item/Adreno_OpenCL_SDK?version=2.3.2
Unzip/untar the archive into
```
c:\Qualcomm\OpenCL_SDK\2.3.2
```
## Install the latest Hexagon SDK Community Edition
Either use the trimmed down version (optimized for CI) from
https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz
Or download the complete official version from
https://softwarecenter.qualcomm.com/catalog/item/Hexagon_SDK?version=6.4.0.2
Unzip/untar the archive into
```
c:\Qualcomm\Hexagon_SDK\6.4.0.2
```
## Install the latest Adreno GPU driver
Download the driver from
https://softwarecenter.qualcomm.com/catalog/item/Windows_Graphics_Driver
After the automated installation and reboot, please make sure that the GPU device shows up in the `Device Manager` (under `Display Adapters`).
## Install the latest Qualcomm NPU driver
Download the driver from
https://softwarecenter.qualcomm.com/catalog/item/Qualcomm_HND
After the automated installation and reboot please make sure that the Hexagon NPU device shows up in the `Device Manager` (under `Neural Processors`).
If the device is not available you can try installing all components (`qcnspmcdm8380`, `qcnspmcdm8380_ext`) manually.
The components are extracted into
```
c:\QCDrivers\qcnspmcdm...
```
## Enable NPU driver test signatures
Please note that the following steps are required only for the Hexagon NPU.
Adreno GPU backend does not require test signatures.
### Enable test-signing
Use `bcdedit` to enable test-signing:
```
> bcdedit /set TESTSIGNING ON
```
(Secure Boot may need to be disabled for this to work.)
After rebooting, make sure test-signing is enabled:
```
> bcdedit /enum
...
testsigning Yes
...
```
For additional details see Microsoft guide at
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/the-testsigning-boot-configuration-option
### Create personal certificate
The tools required for this procedure are available as part of the Windows SDK and Windows Driver Kit, which should be
installed as part of MS Visual Studio.
They are typically located at
```
c:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0
```
(replace `10.0.26100.0` with the correct version).
To create a personal self-signed certificate, run the following commands (from either cmd or PowerShell):
```
> cd c:\Users\MyUser
> mkdir Certs
> cd Certs
> makecert -r -pe -ss PrivateCertStore -n CN=GGML.HTP.v1 -eku 1.3.6.1.5.5.7.3.3 -sv ggml-htp-v1.pvk ggml-htp-v1.cer
> pvk2pfx.exe -pvk ggml-htp-v1.pvk -spc ggml-htp-v1.cer -pfx ggml-htp-v1.pfx
```
(replace `MyUser` with your username).
Add this certificate to the `Trusted Root Certification Authorities` and `Trusted Publishers` stores.
This can be done using the `certlm` Certificate Manager tool.
Right-click on the certificate store, select `All Tasks -> Import` and follow the prompts to import the certificate from the
PFX file you created above.
For additional details see Microsoft guide at
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/introduction-to-test-signing
Make sure to save the PFX file; you will need it for the build procedures.
Please note that the same certificate can be used for signing any number of builds.
## Build Hexagon backend with signed HTP ops libraries
The overall Hexagon backend build procedure for Windows on Snapdragon is the same as for other platforms.
However, additional settings are required for generating and signing the HTP Ops libraries.
```
> $env:OPENCL_SDK_ROOT="C:\Qualcomm\OpenCL_SDK\2.3.2"
> $env:HEXAGON_SDK_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2"
> $env:HEXAGON_TOOLS_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2\tools\HEXAGON_Tools\19.0.04"
> $env:HEXAGON_HTP_CERT="c:\Users\MyUser\Certs\ggml-htp-v1.pfx"
> $env:WINDOWS_SDK_BIN="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\arm64"
> cmake --preset arm64-windows-snapdragon -B build-wos
...
> cmake --install build-wos --prefix pkg-snapdragon
```
Once the build is complete, the HTP ops libraries will be installed like this:
```
> dir pkg-snapdragon/lib
...
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v73.so
-a---- 1/22/2026 6:01 PM 191752 libggml-htp-v75.so
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v79.so
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v81.so
-a---- 1/22/2026 6:01 PM 4139 libggml-htp.cat
```
The `.cat` file, the signature, and proper certificate installation can be verified with
```
> signtool.exe verify /v /pa .\pkg-snapdragon\lib\libggml-htp.cat
Verifying: .\pkg-snapdragon\lib\libggml-htp.cat
Signature Index: 0 (Primary Signature)
Hash of file (sha256): 9820C664DA59D5EAE31DBB664127FCDAEF59CDC31502496BC567544EC2F401CF
Signing Certificate Chain:
Issued to: GGML.HTP.v1
...
Successfully verified: .\pkg-snapdragon\lib\libggml-htp.cat
...
```
+2 -2
@@ -97,7 +97,7 @@ Legend:
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
@@ -114,7 +114,7 @@ Legend:
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
+763 -605
File diff suppressed because it is too large.
+120
@@ -0,0 +1,120 @@
# Speculative Decoding
llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n tokens sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
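The verification step can be illustrated with a minimal sketch. It assumes greedy sampling, where a draft token is accepted only if it equals the token the target model would have picked at that position. `count_accepted` and the `token` alias are hypothetical names for this example and are not part of the llama.cpp API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using token = std::int32_t; // illustrative alias, not the llama.cpp type

// Greedy verification (an assumption of this sketch): target[i] is the token the
// target model itself would pick at the position of draft[i], with all positions
// evaluated in one batch. Returns the number of leading draft tokens accepted.
std::size_t count_accepted(const std::vector<token> & draft,
                           const std::vector<token> & target) {
    assert(target.size() >= draft.size());
    std::size_t n = 0;
    while (n < draft.size() && draft[n] == target[n]) {
        ++n; // accept until the first disagreement
    }
    return n;
}
```

For a draft `{1, 2, 3, 4}` and target picks `{1, 2, 9, 4}`, the first two draft tokens are accepted, and generation continues from the target model's own token at the first mismatch.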
## Implementations
The `llama-server` application supports several implementations of speculative decoding:
### Draft Model (`draft`)
A much smaller model (called the _draft model_) generates drafts.
Using a draft model is the most common approach to speculative decoding.
### n-gram Cache (`ngram-cache`)
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
See:
- #5479, #6828, #6848
### n-gram Map (`ngram-simple`, `ngram-map-*`)
These implementations search the token history for patterns and use matching sequences as draft candidates.
They require no additional model but rely on patterns that have already appeared in the generated text.
A typical use case for this approach is the rewriting of source code by an LLM.
#### n-gram Map (`ngram-simple`)
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
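The lookup described above can be sketched as follows. This is a simplified illustration and not the llama.cpp implementation: `draft_ngram_simple` and the `token` alias are hypothetical names, and the sketch rescans the whole history on every call.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using token = std::int32_t; // illustrative alias, not the llama.cpp type

// Find the most recent earlier occurrence of the last n tokens of `history`
// and return up to m tokens that followed it (empty if there is no match).
std::vector<token> draft_ngram_simple(const std::vector<token> & history,
                                      std::size_t n, std::size_t m) {
    assert(n > 0);
    std::vector<token> draft;
    if (history.size() <= n) {
        return draft; // not enough history to search
    }
    const std::size_t cur = history.size() - n; // start of the current n-gram
    // scan candidate start positions from most recent to oldest
    for (std::size_t start = cur; start-- > 0; ) {
        bool match = true;
        for (std::size_t j = 0; j < n; ++j) {
            if (history[start + j] != history[cur + j]) {
                match = false;
                break;
            }
        }
        if (!match) {
            continue;
        }
        // copy up to m tokens that followed the matched n-gram
        for (std::size_t k = start + n; k < history.size() && draft.size() < m; ++k) {
            draft.push_back(history[k]);
        }
        break;
    }
    return draft;
}
```

With the history `{1, 2, 3, 4, 5, 1, 2}`, `n = 2` and `m = 3`, the current n-gram `{1, 2}` matches the start of the history and the draft is `{3, 4, 5}`.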
#### n-gram Map Key (`ngram-map-k`)
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts.
The number of accepted tokens is stored for each used n-gram.
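A rough sketch of the key/m-gram idea follows. It is one possible reading of the description above, not the actual implementation: the m-gram after the most recent occurrence of the key is the candidate, and it is used only if it followed the key at least `min_hits` times. The names `draft_ngram_map_k` and `token` are hypothetical, and a real implementation would keep a map instead of rescanning the history.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using token = std::int32_t; // illustrative alias, not the llama.cpp type

std::vector<token> draft_ngram_map_k(const std::vector<token> & history,
                                     std::size_t n, std::size_t m,
                                     std::size_t min_hits) {
    assert(n > 0 && m > 0);
    std::vector<token> mgram;
    if (history.size() <= n) {
        return mgram; // not enough history to form a key plus earlier context
    }
    const std::size_t cur = history.size() - n; // start of the key n-gram
    std::size_t hits = 0;
    for (std::size_t s = cur; s-- > 0; ) {
        if (s + n + m > history.size()) {
            continue; // the m-gram after this position would be incomplete
        }
        if (!std::equal(history.begin() + cur, history.end(), history.begin() + s)) {
            continue; // key does not occur at position s
        }
        const auto first = history.begin() + s + n;
        if (mgram.empty()) {
            mgram.assign(first, first + m); // most recent occurrence is the candidate
            hits = 1;
        } else if (std::equal(mgram.begin(), mgram.end(), first)) {
            ++hits; // same m-gram followed the key again
        }
    }
    if (hits < min_hits) {
        mgram.clear(); // not seen often enough to draft
    }
    return mgram;
}
```

For the history `{1, 2, 8, 9, 1, 2, 8, 9, 1, 2}` with `n = 2`, `m = 2` and `min_hits = 2`, the key `{1, 2}` was followed by `{8, 9}` twice, so `{8, 9}` is drafted. With `min_hits = 3` no draft is produced.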
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
The number of accepted tokens is stored for each used n-gram.
**Example:** Server options to use when the text contains many long repetitions.
```bash
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2
```
## Command-Line Options
If a draft model is combined with a draftless decoding type, the draftless decoding takes precedence.
```
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]
type of speculative decoding to use when no draft model is provided
(default: none)
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
of lookup n-gram (default: 12)
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
of draft m-gram (default: 48)
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
(default: 1)
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
```
### `--spec-type TYPE`
Specifies the type of speculative decoding to use without a draft model.
| Type | Description |
|------|-------------|
| `none` | No speculative decoding (default) |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
**Example:** Server instance used to refactor source code.
```bash
./llama-server [...] --spec-type ngram-simple
```
### `--spec-ngram-size-n N`
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
### `--spec-ngram-size-m M`
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
The m-gram size determines how many tokens to draft when a match is found.
Larger values can provide more speedup but may reduce acceptance rate.
### `--spec-ngram-check-rate R`
This option reduces overhead when the n-gram lookup in the token history is too costly. A lookup is executed every R tokens (default is 1, i.e. every token).
### `--spec-ngram-min-hits H`
This option defines how often a key has to appear in the token history before it is used to create a draft (default is 1).
## Statistics
Each speculative decoding implementation prints statistics.
```
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
```
- `#calls`: number of calls of this implementation
- `#gen drafts`: number of drafts generated by this implementation
- `#acc drafts`: number of drafts accepted (at least partially) by the main model
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
- `#acc tokens`: number of tokens accepted by the main model
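As a small worked example, the acceptance rate in the first line of the sample output is the ratio of accepted to generated tokens. In that sample the totals also happen to equal the sums of the per-implementation counters (73 + 98 = 171 accepted, 187 + 110 = 297 generated). The helper name `acceptance_rate` is hypothetical and only illustrates the arithmetic.

```cpp
#include <cassert>

// Ratio of accepted to generated draft tokens; returns 0 when nothing was generated.
double acceptance_rate(int n_accepted, int n_generated) {
    assert(n_accepted >= 0 && n_generated >= 0);
    return n_generated > 0 ? static_cast<double>(n_accepted) / n_generated : 0.0;
}
```

Calling `acceptance_rate(73 + 98, 187 + 110)` yields approximately 0.57576, matching the sample output.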
+2 -2
@@ -32,9 +32,9 @@ int main(int argc, char ** argv){
common_ngram_cache ngram_cache;
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
common_ngram_cache_save(ngram_cache, params.lookup_cache_static);
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
return 0;
}
+5 -5
@@ -46,18 +46,18 @@ int main(int argc, char ** argv){
{
const int64_t t_start_draft_us = ggml_time_us();
if (!params.lookup_cache_static.empty()) {
if (!params.speculative.lookup_cache_static.empty()) {
try {
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
} catch (std::ifstream::failure const &) {
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.lookup_cache_dynamic.empty()) {
if (!params.speculative.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}
+6 -6
@@ -51,18 +51,18 @@ int main(int argc, char ** argv){
const int64_t t_start_draft_us = ggml_time_us();
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
if (!params.lookup_cache_static.empty()) {
if (!params.speculative.lookup_cache_static.empty()) {
try {
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
} catch (std::ifstream::failure const &) {
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.lookup_cache_dynamic.empty()) {
if (!params.speculative.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}
@@ -210,7 +210,7 @@ int main(int argc, char ** argv){
// Update dynamic ngram cache with context ngram cache and save it to disk:
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
LOG("\n\n");
@@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
common_init();
if (params.speculative.model.path.empty()) {
if (params.speculative.mparams_dft.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@@ -34,10 +34,8 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
llama_model * model_tgt = NULL;
//llama_model * model_dft = NULL;
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the target model
auto llama_init_tgt = common_init_from_params(params);
@@ -48,26 +46,38 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.n_ctx = params.speculative.n_ctx;
params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
params.n_gpu_layers = params.speculative.n_gpu_layers;
llama_model_ptr model_dft;
if (params.speculative.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
}
// TODO: simplify this logic
{
const auto & params_spec = params.speculative;
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
auto params_dft = params;
auto llama_init_dft = common_init_from_params(params);
params_dft.n_parallel = 1;
params_dft.n_ctx = params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams_dft;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
//model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
if (params_spec.cpuparams.n_threads > 0) {
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
}
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
auto mparams_dft = common_model_params_to_llama(params_dft);
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_dft == nullptr) {
LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
return 1;
}
params.speculative.model_dft = model_dft.get();
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
}
// Tokenize the prompt
@@ -92,12 +102,6 @@ int main(int argc, char ** argv) {
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
}
// how many tokens to draft each time
int n_draft = params.speculative.n_max;
int n_draft_min = params.speculative.n_min;
float p_min = params.speculative.p_min;
int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;
@@ -127,15 +131,11 @@ int main(int argc, char ** argv) {
int n_past = inp.size() - 1;
// init the speculator
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft;
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
params_spec.p_min = p_min;
const auto & params_spec = params.speculative;
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
for (auto &pair : params.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
}
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
common_speculative_begin(spec, prompt_tgt);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
@@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
// from a cache or lookup tables.
//
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
// do not waste time on small drafts
if (draft.size() < (size_t) n_draft_min) {
if (draft.size() < (size_t) params_spec.n_min) {
draft.clear();
}
@@ -240,7 +240,7 @@ int main(int argc, char ** argv) {
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_INF("\n");
LOG_INF("n_draft = %d\n", n_draft);
LOG_INF("n_draft = %d\n", params_spec.n_max);
LOG_INF("n_predict = %d\n", n_predict);
LOG_INF("n_drafted = %d\n", n_drafted);
LOG_INF("n_accept = %d\n", n_accept);
@@ -249,8 +249,6 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("draft:\n\n");
llama_perf_context_print(ctx_dft);
LOG_INF("\n");
LOG_INF("target:\n\n");
common_perf_print(ctx_tgt, smpl);
+2 -2
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
common_init();
if (params.speculative.model.path.empty()) {
if (params.speculative.mparams_dft.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.model = params.speculative.mparams_dft;
params.n_gpu_layers = params.speculative.n_gpu_layers;
if (params.speculative.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+1
@@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC)
endif()
add_library(ggml
ggml-backend-dl.cpp
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)
+48
@@ -0,0 +1,48 @@
#include "ggml-backend-dl.h"
#ifdef _WIN32
dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
const char * dl_error() {
return "";
}
#else
dl_handle * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
+45
@@ -0,0 +1,45 @@
#pragma once
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
#endif
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
dl_handle * dl_load_library(const fs::path & path);
void * dl_get_sym(dl_handle * handle, const char * name);
const char * dl_error();
+1 -66
@@ -1,5 +1,6 @@
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-backend-dl.h"
#include "ggml-impl.h"
#include <algorithm>
#include <cstring>
@@ -98,72 +99,6 @@ static std::string path_str(const fs::path & path) {
}
}
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
static dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
static void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
static const char * dl_error() {
return "";
}
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
static void * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
static void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
struct ggml_backend_reg_entry {
ggml_backend_reg_t reg;
dl_handle_ptr handle;
+11 -2
@@ -1122,15 +1122,18 @@ struct ggml_tensor_extra_gpu {
#endif
struct ggml_cuda_graph_node_properties {
void * node_address;
void * node_data;
ggml_op node_op;
enum ggml_type node_type;
int32_t flags;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
void * src_data[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
~ggml_cuda_graph() {
@@ -1150,6 +1153,12 @@ struct ggml_cuda_graph {
int number_consecutive_updates = 0;
std::vector<ggml_cuda_graph_node_properties> props;
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
// their properties also have to match in order to be able to safely reuse a CUDA graph
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
std::vector<ggml_cuda_graph_node_properties> extra;
void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) {
number_consecutive_updates++;
-5
@@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
}
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
const int cc = ggml_cuda_info().devices[device].cc;
switch (K->ne[0]) {
@@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
if (!V_is_K_view) {
return BEST_FATTN_KERNEL_NONE;
}
break;
default:
return BEST_FATTN_KERNEL_NONE;
+269 -92
@@ -70,17 +70,18 @@
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <float.h>
#include <cfloat>
#include <initializer_list>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
#include <unordered_set>
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
@@ -2916,22 +2917,27 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
}
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
props->node_address = node->data;
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
props->node_data = node->data;
props->node_op = node->op;
props->node_type = node->type;
props->flags = node->flags;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
props->ne[i] = node->ne[i];
props->nb[i] = node->nb[i];
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
if (!node->src[i]) {
continue;
}
props->src_data[i] = node->src[i]->data;
}
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
if (node->data != props->node_address &&
node->op != GGML_OP_VIEW) {
if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
return false;
}
@@ -2939,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
return false;
}
if (node->type != props->node_type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != props->ne[i]) {
return false;
@@ -2948,12 +2958,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != props->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
if (node->op != GGML_OP_VIEW) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (!node->src[i]) {
if (props->src_data[i] != nullptr) {
return false;
}
continue;
}
if (node->src[i]->data != props->src_data[i]) {
return false;
}
}
}
@@ -2974,7 +2990,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
}
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
bool res = false;
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
@@ -2985,15 +3000,20 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
}
// Check if the graph size has changed
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
if (graph->props.size() != (size_t)cgraph->n_nodes) {
res = true;
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
graph->props.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
std::unordered_set<ggml_tensor *> seen_node;
std::vector<ggml_tensor *> srcs_extra;
for (int i = 0; i < cgraph->n_nodes; i++) {
bool props_match = true;
seen_node.insert(cgraph->nodes[i]);
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
}
@@ -3001,17 +3021,31 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
res = true;
}
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
if (src && seen_node.find(src) == seen_node.end()) {
srcs_extra.push_back(src);
}
}
}
for (int i = 0; i < cgraph->n_leafs; i++) {
if (graph->extra.size() != (size_t) srcs_extra.size()) {
res = true;
graph->extra.resize(srcs_extra.size());
}
for (size_t i = 0; i < srcs_extra.size(); ++i) {
bool props_match = true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
}
if (!props_match) {
res = true;
}
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
}
return res;
@@ -3080,63 +3114,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
return true;
}
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
args.sigmoid = false;
args.softmax = false;
args.delayed_softmax = false;
args.prob_bias = false;
args.norm = false;
const int n_nodes = cgraph->n_nodes;
ggml_tensor ** nodes = cgraph->nodes;
if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
args.softmax = true;
}
if (nodes[node_idx]->op == GGML_OP_UNARY) {
if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
return false;
}
args.sigmoid = true;
}
if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
args.delayed_softmax = true;
}
node_idx++;
if (args.sigmoid || args.softmax) {
// SOFTMAX -> RESHAPE
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx];
node_idx++;
if (node_idx >= n_nodes) {
return false;
}
// src of bias add is the unreshaped probs (-2 instead of -1)
if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
args.prob_bias = true;
node_idx++;
}
// RESHAPE/ADD -> ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
return false;
}
if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
} else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
return false;
}
node_idx++;
// ARGSORT-> VIEW
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
return false;
}
// GET_ROWS
if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
} else if (args.delayed_softmax) {
if (node_idx - 2 < 0) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx - 2];
// VIEW->ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
// GET_ROWS
if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != probs_reshaped) {
return false;
}
node_idx++;
static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
for (const ggml_op op : remaining_ops) {
if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
}
}
// At this point we can check for norm + scale. Everything is now at least valid till the norm
if (node_idx >= n_nodes) {
return true;
}
if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
//check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
args.norm = true;
for (const ggml_op op : norm_ops) {
if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
node_idx++;
} else {
args.norm = false;
return true;
}
}
// DIV <- CLAMP, RESHAPE
if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
args.norm = false;
return true;
}
node_idx++;
if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
args.norm = false;
return true;
}
node_idx++;
}
if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
args.scale = true;
}
return true;
}
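
The matcher above walks the graph node by node and, at each step, checks three things: the index is in bounds, the op is the expected one, and the node consumes its predecessor. A minimal standalone sketch of that pattern follows. `Op`, `Node`, and `matches_chain` are hypothetical stand-ins for illustration only, not ggml types or functions, and the sketch covers only the simple linear case (no bias branch, no `src[1]` checks).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for ggml_op and ggml_tensor; not the real ggml types.
enum class Op { SoftMax, Reshape, Argsort, View, GetRows };

struct Node {
    Op           op;
    const Node * src0 = nullptr;  // first input, analogous to src[0]
};

// Returns true if nodes[start..] begins with the ops in `expected`, where every
// node after the first consumes its immediate predecessor through src0.
static bool matches_chain(const std::vector<const Node *> & nodes,
                          size_t                            start,
                          const std::vector<Op> &           expected) {
    assert(!expected.empty());
    for (size_t k = 0; k < expected.size(); ++k) {
        const size_t idx = start + k;
        // Bounds check first, then the op check, so we never read past the graph.
        if (idx >= nodes.size() || nodes[idx]->op != expected[k]) {
            return false;
        }
        // Data must flow from the previous node in the chain.
        if (k > 0 && nodes[idx]->src0 != nodes[idx - 1]) {
            return false;
        }
    }
    return true;
}
```

Putting the bounds test in the same condition as the op test, and ahead of it, means a truncated graph fails the match cleanly instead of reading out of range.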
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
GGML_ASSERT(unary_ops.size() == num_unary);
#endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes
std::initializer_list<enum ggml_op> topk_moe_ops =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
@@ -3398,35 +3535,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
ggml_cuda_topk_moe_args args;
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 9];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_tensor * clamp = cgraph->nodes[i + 7];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false, clamp);
i += 9;
continue;
}
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i + 4];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;
continue;
}
std::vector<ggml_op> ops;
if (ggml_cuda_can_fuse(cgraph, i,
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 5];
ggml_tensor * ids = cgraph->nodes[i + 1];
if (can_fuse) {
const ggml_tensor * logits = node->src[0];
ggml_tensor * weights = nullptr;
ggml_tensor * ids = nullptr;
const ggml_tensor * bias = nullptr;
const ggml_tensor * clamp = nullptr;
const ggml_tensor * scale = nullptr;
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
/*delayed_softmax*/ true);
i += 5;
continue;
if (!args.delayed_softmax) {
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
int out_nodes[2]; // nodes which can't be elided
if (args.prob_bias) {
bias = cgraph->nodes[i + 2]->src[1];
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS });
out_nodes[0] = i + 4;
ids = cgraph->nodes[i + 4];
} else {
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS });
out_nodes[0] = i + 3;
ids = cgraph->nodes[i + 3];
}
if (args.norm) {
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE });
clamp = cgraph->nodes[i + ops.size() - 3];
}
if (args.scale) {
ops.insert(ops.end(), { GGML_OP_SCALE });
scale = cgraph->nodes[i + ops.size() - 1];
}
weights = cgraph->nodes[i + ops.size() - 1];
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
} else if (!args.norm && !args.prob_bias) {
//special case gpt-oss, no norm, no bias.
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
weights = cgraph->nodes[i + 5];
ids = cgraph->nodes[i + 1];
const ggml_tensor * softmax = cgraph->nodes[i + 4];
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
}
}
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
+98 -4
@@ -333,7 +333,33 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
@@ -391,7 +417,22 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
@@ -945,6 +986,32 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
using floatx2_t = __attribute__((ext_vector_type(2))) float;
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
@@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma {
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // RDNA4
#elif defined(AMD_MFMA_AVAILABLE)
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma {
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // RDNA4
#endif // defined(RDNA4)
#elif defined(AMD_MFMA_AVAILABLE)
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3) || defined(CDNA2)
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
#endif // defined(CDNA3) || defined(CDNA2)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
}
template <data_layout dl_d, data_layout dl_ab>
+30 -10
@@ -2,6 +2,13 @@
#include "mmf.cuh"
#include "mmid.cuh"
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return MMF_ROWS_PER_BLOCK_CDNA;
} else {
return MMF_ROWS_PER_BLOCK;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
@@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
ids_info_ptr = &ids_info;
}
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int rows_per_block = mmf_get_rows_per_block(cc);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<float>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<half2>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<nv_bfloat162>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
@@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
return false;
}
if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
return false;
}
@@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
} else {
if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
return false;
} else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
// TODO: treat CDNA2 as CDNA1 for now; tune the performance when CDNA2 hardware is available.
return false;
} else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
return false;
} else if (src1_ncols > 16) {
return false;
}
@@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
return ampere_mma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_F16:
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc) || amd_wmma_available(cc);
return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
default:
return false;
}
+158 -85
@@ -7,6 +7,31 @@
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
#define MMF_ROWS_PER_BLOCK_CDNA 64
static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 512;
} else {
return 256;
}
}
static __forceinline__ int mmf_get_padding(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 2;
} else {
return 4;
}
}
static constexpr __device__ int mmf_get_padding() {
#if defined(AMD_MFMA_AVAILABLE)
return 2;
#else
return 4;
#endif // defined(AMD_MFMA_AVAILABLE)
}
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
@@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
@@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
}
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
@@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
float sum[rows_per_block/warp_size] = {0.0f};
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
#pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i];
sum[i1] += buf_iw[j*kiw + i];
}
}
if constexpr (!has_ids) {
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif //VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
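
The final reduction in the kernel above replaces the scalar `sum` with an array so that `rows_per_block` only needs to be a multiple of `warp_size`, not equal to it. A host-side sketch of the same indexing for a single output column is shown below. `combine_partials` is a hypothetical helper written for illustration, not part of the code base, and the lane loop stands in for `threadIdx.x`.

```cpp
#include <cassert>
#include <vector>

// Host-side model of the final reduction for one output column.
// `buf` holds nwarps * rows_per_block partial results; each of the `warp_size`
// lanes owns rows_per_block / warp_size output rows and sums them over all warps.
static std::vector<float> combine_partials(const std::vector<float> & buf,
                                           int nwarps, int rows_per_block, int warp_size) {
    assert(rows_per_block % warp_size == 0);
    assert((int) buf.size() == nwarps * rows_per_block);

    const int rows_per_lane = rows_per_block / warp_size;
    std::vector<float> out(rows_per_block, 0.0f);

    for (int lane = 0; lane < warp_size; ++lane) {  // plays the role of threadIdx.x
        for (int i0 = 0; i0 < nwarps * rows_per_block; i0 += rows_per_block) {
            for (int i1 = 0; i1 < rows_per_lane; ++i1) {
                // Same index expression as the kernel: i0 + i1*warp_size + lane.
                out[i1 * warp_size + lane] += buf[i0 + i1 * warp_size + lane];
            }
        }
    }
    return out;
}
```

With `rows_per_block == warp_size` the inner loop runs exactly once, which matches the old scalar code path.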
//This kernel is for larger batch sizes of mul_mat_id
@@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
@@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
@@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
float sum[rows_per_block/warp_size] = {0.0f};
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
#pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i];
sum[i1] += buf_iw[j * kiw + i];
}
}
const int global_j = col_base + j;
@@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif // VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
template<typename T, int cols_per_block, int nwarps>
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
@@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
}
template <typename T, int cols_per_block>
template <typename T, int rows_per_block, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -605,7 +645,7 @@ void mul_mat_f_cuda(
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
int64_t max_block_size = mmf_get_max_block_size(cc);
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
@@ -614,10 +654,9 @@ void mul_mat_f_cuda(
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
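
The shared-memory sizing above can be checked with a small host-side calculation. This is a sketch under stated assumptions: `pad_to` is a hypothetical stand-in for `GGML_PAD`, `SharedMemConfig` is a hypothetical bundle of the values the real code derives from `cc`, and the slot-map bytes added for `ids` are left out.

```cpp
#include <algorithm>
#include <cassert>

// Round x up to the next multiple of n. Hypothetical helper standing in for GGML_PAD.
static int pad_to(int x, int n) {
    assert(n > 0);
    return (x + n - 1) / n * n;
}

// Hypothetical bundle of the device-dependent values used in the sizing.
struct SharedMemConfig {
    int warp_size;
    int padding;      // 2 on CDNA, 4 otherwise, per mmf_get_padding(cc)
    int tile_a_rows;  // tile_A::I
    int tile_b_rows;  // tile_B::I
};

// Bytes needed by the larger of the two phases: the per-iteration tiles and
// the cross-warp combine buffer. The factor 4 is the literal used in the diff.
static int shared_bytes(const SharedMemConfig & c, int nwarps, int rows_per_block, int cols_per_block) {
    const int iter    = nwarps * c.tile_a_rows * (c.warp_size + c.padding) * 4;
    const int combine = pad_to(cols_per_block, c.tile_b_rows) * (nwarps * rows_per_block + c.padding) * 4;
    return std::max(iter, combine);
}
```

For example, a CDNA-like configuration (warp size 64, padding 2, 16-row tiles, 4 warps, 64 rows and 8 columns per block) needs 16896 bytes for the iteration phase and 16512 for the combine phase, so the iteration phase sets the allocation.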
@@ -628,56 +667,56 @@ void mul_mat_f_cuda(
switch (nwarps_best) {
case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
@@ -691,7 +730,7 @@ void mul_mat_f_cuda(
GGML_UNUSED_VARS(nchannels_y);
}
template <typename T>
template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
switch (ncols_case) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
@@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
}
}
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, ncols_dst>( \
template <typename T>
static void mul_mat_f_switch_rows_per_block(
const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
switch (rows_per_block) {
case MMF_ROWS_PER_BLOCK: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case MMF_ROWS_PER_BLOCK_CDNA: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
default:
GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
}
}
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
@@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
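Review note: the dispatchers in this file turn runtime values (`rows_per_block`, `ncols_case`, `nwarps_best`) into compile-time template arguments by nesting `switch` statements, one level per parameter. Below is a minimal host-only C++ sketch of that pattern. The names `launch_tile`, `switch_cols`, `switch_rows`, `ROWS_DEFAULT` and `ROWS_WIDE` are hypothetical, and the values 32 and 64 are placeholders only, since the actual values of `MMF_ROWS_PER_BLOCK` and `MMF_ROWS_PER_BLOCK_CDNA` are not shown in this diff.

```cpp
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for a kernel launcher that is templated on tile sizes.
// It only encodes its template arguments so the dispatch can be observed.
template <int rows_per_block, int cols_per_block>
int launch_tile() {
    return rows_per_block * 100 + cols_per_block;
}

// Inner level: map the runtime column count onto a template argument.
template <int rows_per_block>
int switch_cols(const int cols) {
    switch (cols) {
        case 1: return launch_tile<rows_per_block, 1>();
        case 2: return launch_tile<rows_per_block, 2>();
        case 3: return launch_tile<rows_per_block, 3>();
        case 4: return launch_tile<rows_per_block, 4>();
        default:
            std::fprintf(stderr, "unsupported cols: %d\n", cols);
            std::abort();
    }
}

// Placeholder tile heights (assumed values, not taken from the diff).
constexpr int ROWS_DEFAULT = 32;
constexpr int ROWS_WIDE    = 64;

// Outer level: map the runtime row count onto a template argument,
// then forward to the inner dispatcher.
int switch_rows(const int rows, const int cols) {
    switch (rows) {
        case ROWS_DEFAULT: return switch_cols<ROWS_DEFAULT>(cols);
        case ROWS_WIDE:    return switch_cols<ROWS_WIDE>(cols);
        default:
            std::fprintf(stderr, "unsupported rows: %d\n", rows);
            std::abort();
    }
}
```

Each added level multiplies the number of instantiations, which is presumably why the explicit `DECL_MMF_CASE` lists above grow from three to six entries per `ncols_dst` once the second row-tile size is introduced.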
+191 -139
@@ -5,6 +5,13 @@
#include <cmath>
#include <initializer_list>
// Kernel config struct - passed by value to CUDA kernel
struct topk_moe_config {
bool use_sigmoid;
bool with_norm;
bool delayed_softmax;
};
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
}
}
template <int experts_per_thread, bool use_limit>
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
}
}
/*
This kernel does the following:
1. optionally softmax over the logits per token [n_experts, n_tokens]
@@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert_used,
const float clamp_val) {
template <int n_experts, bool has_bias>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert_used,
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
@@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float wt[experts_per_thread];
// Initialize all slots to -INFINITY
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
}
if constexpr (!delayed_softmax) {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
if (!config.delayed_softmax) {
if (config.use_sigmoid) {
sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
} else {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
}
}
// selection_wt is only needed when bias is present (selection uses wt + bias)
// when no bias, we use wt directly for both selection and weight values
float selection_wt[has_bias ? experts_per_thread : 1];
if constexpr (has_bias) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
selection_wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
selection_wt[i / WARP_SIZE] =
(n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
}
}
//at this point, each thread holds either a portion of the softmax distribution
@@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float max_val = wt[0];
int max_expert = threadIdx.x;
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
if constexpr (has_bias) {
float max_val_s = selection_wt[0];
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
max_val = wt[i];
max_val_s = selection_wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
max_val = val;
max_val_s = val_s;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
selection_wt[max_expert / WARP_SIZE] = -INFINITY;
}
} else {
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
}
}
@@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[k] = max_expert;
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum += max_val;
}
}
}
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum = warp_reduce_sum(wt_sum);
wt_sum = max(wt_sum, clamp_val);
const float inv_sum = 1.0f / wt_sum;
@@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}
if constexpr (delayed_softmax) {
if (config.delayed_softmax) {
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
}
@@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
weights[idx] = output_weights[i] * scale_val;
}
}
if (!with_norm) {
GGML_UNUSED(clamp_val);
}
}
template <bool with_norm, bool delayed_softmax = false>
template<bool has_bias>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert,
const int n_expert_used,
const float clamp_val) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
"delayed softmax is not supported with weight normalization");
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 2:
topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 4:
topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 8:
topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 16:
topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 32:
topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 64:
topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 128:
topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 256:
topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 512:
topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 576:
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
default:
GGML_ASSERT(false && "fatal error");
@@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
}
}
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax,
ggml_tensor * clamp) {
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;
float * bias_d = bias ? (float *) bias->data : nullptr;
float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
const int n_expert_used = weights->ne[1];
const bool with_norm = clamp != nullptr;
float clamp_val = -INFINITY;
if (with_norm) {
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
topk_moe_config config;
config.use_sigmoid = args.sigmoid;
config.with_norm = with_norm;
config.delayed_softmax = args.delayed_softmax;
if (bias) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
} else {
GGML_ASSERT(clamp == nullptr);
if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
}
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
const ggml_tensor * logits,
const ggml_tensor * ids) {
const int n_expert = ids->nb[1] / ids->nb[0];
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
return false;
}
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
if (gating_op->op == GGML_OP_SOFT_MAX) {
const ggml_tensor * softmax = gating_op;
float scale = 1.0f;
float max_bias = 0.0f;
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
}
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
if (!ggml_is_contiguous(softmax->src[0])) {
return false;
}
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) {
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
} else if (gating_op->op == GGML_OP_UNARY) {
ggml_unary_op op = ggml_get_unary_op(gating_op);
if (op != GGML_UNARY_OP_SIGMOID) {
return false;
}
}
return true;
}
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
GGML_ASSERT(!norm || !delayed_softmax);
if (delayed_softmax) {
return delayed_softmax_ops;
}
if (norm) {
return norm_ops;
}
return no_norm_ops;
}
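Review note: the reworked kernel is easier to check against a serial reference. The sketch below is a single-row CPU version of what the kernel appears to compute, as far as it can be read from this diff: optional sigmoid or softmax gating, top-k selection on `wt + bias` while the emitted weights come from the unbiased `wt`, optional normalization with a clamped sum, optional delayed softmax over the selected values, and a final multiplication by `scale_val`. The names `topk_ref_config`, `topk_ref_result`, `softmax_inplace` and `topk_moe_ref` are hypothetical and not part of ggml.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical mirror of the topk_moe_config struct in the diff.
struct topk_ref_config {
    bool use_sigmoid;
    bool with_norm;
    bool delayed_softmax;
};

struct topk_ref_result {
    std::vector<int>   ids;      // selected expert indices, best first
    std::vector<float> weights;  // routing weights for the selected experts
};

// Numerically stable softmax over the whole vector.
static void softmax_inplace(std::vector<float> & v) {
    float max_val = -INFINITY;
    for (float x : v) max_val = std::max(max_val, x);
    float sum = 0.0f;
    for (float & x : v) { x = std::exp(x - max_val); sum += x; }
    for (float & x : v) x /= sum;
}

// Serial reference for one row of logits. bias may be nullptr.
topk_ref_result topk_moe_ref(const std::vector<float> & logits, const float * bias,
                             const int n_expert_used, const float clamp_val,
                             const float scale_val, const topk_ref_config cfg) {
    assert(!(cfg.with_norm && cfg.delayed_softmax));
    assert(n_expert_used <= (int) logits.size());

    std::vector<float> wt = logits;
    if (!cfg.delayed_softmax) {
        if (cfg.use_sigmoid) {
            for (float & x : wt) x = 1.0f / (1.0f + std::exp(-x));
        } else {
            softmax_inplace(wt);
        }
    }

    // Selection scores: gated value plus bias when a bias is present.
    std::vector<float> sel = wt;
    if (bias != nullptr) {
        for (size_t i = 0; i < sel.size(); i++) sel[i] += bias[i];
    }

    topk_ref_result out;
    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; k++) {
        int best = 0;
        for (int i = 1; i < (int) sel.size(); i++) {
            if (sel[i] > sel[best]) best = i;  // strict '>' keeps the lowest index on ties
        }
        out.ids.push_back(best);
        out.weights.push_back(wt[best]);       // weight is the unbiased value
        wt_sum += wt[best];
        sel[best] = -INFINITY;                 // exclude from later rounds
    }

    if (cfg.with_norm) {
        wt_sum = std::max(wt_sum, clamp_val);
        for (float & w : out.weights) w /= wt_sum;
    }
    if (cfg.delayed_softmax) {
        softmax_inplace(out.weights);
    }
    for (float & w : out.weights) w *= scale_val;
    return out;
}
```

With all-zero logits, sigmoid gating and a bias of `{0, 0, 1, 0.5}`, the reference picks experts 2 and 3 but reports equal weights for both, which illustrates that the bias influences only which experts are chosen, not the weights they receive.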
+20 -14
@@ -3,19 +3,25 @@
#include <initializer_list>
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
struct ggml_cuda_topk_moe_args {
bool sigmoid{};
bool softmax{};
bool delayed_softmax{};
bool prob_bias{};
bool norm{};
bool scale{};
};
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
const ggml_tensor * logits,
const ggml_tensor * ids);
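Review note: with the new signature, `ggml_cuda_should_use_topk_moe` derives `n_expert` from `ids` and accepts it when it is a power of two no larger than 512, or exactly 576, which matches the `case` labels in `launch_topk_moe_cuda`. The predicate can be written positively as in the sketch below. The function name `topk_moe_supported_n_expert` is hypothetical, and the extra `n_expert > 0` guard is an addition of this sketch that the condition in the diff does not contain.

```cpp
// Returns true when the expert count is one of the sizes the fused kernel
// is instantiated for: 1, 2, 4, ..., 512, or the special case 576.
bool topk_moe_supported_n_expert(const int n_expert) {
    const bool is_pow2     = n_expert > 0 && (n_expert & (n_expert - 1)) == 0;
    const bool pow2_le_512 = is_pow2 && n_expert <= 512;
    return pow2_le_512 || n_expert == 576;
}
```

Keeping this check in step with the launcher's `switch` matters because an accepted but uninstantiated expert count would fall through to the `default` branch and hit the `GGML_ASSERT(false && "fatal error")` there.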
+69 -44
@@ -1,7 +1,17 @@
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
endif()
message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT
@@ -25,56 +35,71 @@ else()
target_link_options(htp_iface PUBLIC -ldl)
endif()
link_custom_library(htp_iface cdsprpc)
link_custom_library(htp_iface rpcmem)
set(TARGET_NAME ggml-hexagon)
ggml_add_backend_library(${TARGET_NAME}
ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
ggml-hexagon.cpp
htp-drv.cpp
htp-drv.h
libdl.h
../../include/ggml-hexagon.h)
target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
# Build HTP bits
set(HTP_CMAKE_ARGS
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
# Build HTP skels
set(HTP_SKELS)
function(build_htp_skel V)
ExternalProject_Add(htp-${V}
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so
CMAKE_ARGS
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
-DDSP_VERSION=${V}
-DPREBUILT_LIB_DIR="toolv19_${V}")
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
endfunction()
ExternalProject_Add(htp-v68
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68")
ExternalProject_Add(htp-v69
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69")
ExternalProject_Add(htp-v73
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
ExternalProject_Add(htp-v75
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
ExternalProject_Add(htp-v79
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
ExternalProject_Add(htp-v81
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
build_htp_skel(v68)
build_htp_skel(v69)
build_htp_skel(v73)
build_htp_skel(v75)
build_htp_skel(v79)
build_htp_skel(v81)
# Install Hexagon skels required at runtime
install(FILES
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
TYPE LIB)
install(FILES ${HTP_SKELS} TYPE LIB)
if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT)
file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64)
file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86)
file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64)
file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86)
set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86})
find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED)
find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED)
message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels")
set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat)
add_custom_target(libggml-htp-cat
BYPRODUCTS ${LIBGGML_HTP_CAT}
DEPENDS libggml-htp.inf ${HTP_SKELS}
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR}
COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64
COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT}
COMMENT "generating and signing libggml-htp.cat file"
VERBATIM
)
add_dependencies(${TARGET_NAME} libggml-htp-cat)
install(FILES ${LIBGGML_HTP_CAT} TYPE LIB)
endif()
+27 -32
@@ -14,9 +14,6 @@
#ifdef _WIN32
# include <sal.h>
# ifndef _WINDOWS
# define _WINDOWS
# endif
#else
# include <semaphore.h>
# include <unistd.h>
@@ -25,8 +22,6 @@
#pragma clang diagnostic ignored "-Wnested-anon-types"
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#include "htp-utils.h"
#include <AEEStdErr.h>
#include <dspqueue.h>
#include <rpcmem.h>
@@ -40,6 +35,7 @@
#include "op-desc.h"
#include "htp-msg.h"
#include "htp_iface.h"
#include "htp-drv.h"
static size_t opt_ndev = 1;
static size_t opt_nhvx = 0; // use all
@@ -150,9 +146,9 @@ void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_
0, // flags - the framework will autoset this
n_bufs, // number of buffers
bufs, // buffer references
sizeof(req),
sizeof(req), // Message length
(const uint8_t *) &req, // Message
1000000 // Timeout
DSPQUEUE_TIMEOUT // Timeout
);
if (err != 0) {
@@ -182,13 +178,13 @@ void ggml_hexagon_session::flush() {
// Read response packet from queue
int err = dspqueue_read(q, &flags,
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
&n_bufs, // Number of buffer references
bufs, // Buffer references
sizeof(rsp), // Max message length
&rsp_size, // Message length
(uint8_t *) &rsp,
1000000); // Timeout
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
&n_bufs, // Number of buffer references
bufs, // Buffer references
sizeof(rsp), // Max message length
&rsp_size, // Message length
(uint8_t *) &rsp, // Message
DSPQUEUE_TIMEOUT); // Timeout
if (err == AEE_EEXPIRED) {
// TODO: might need to bail out if the HTP is stuck on something
@@ -269,13 +265,7 @@ struct ggml_backend_hexagon_buffer_context {
ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
size += 4 * 1024; // extra page for padding
if (rpcmem_alloc2) {
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
} else {
GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str());
this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
}
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
if (!this->base) {
GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
@@ -2461,12 +2451,12 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
}
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type));
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
}
static inline bool is_compute_op(ggml_tensor *node)
{
return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
}
// scan the graph and figure out last compute op index
@@ -2488,7 +2478,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
const int last = last_compute_op(graph);
const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer
const struct ggml_tensor * prev_op = nullptr; // prev executed op
for (int i = 0; i < graph->n_nodes; ++i) {
ggml_tensor * node = graph->nodes[i];
@@ -2497,17 +2487,15 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
uint32_t flags = 0;
// skip quantizer if src1 is reused
if (op_reuse_src1(node, prev_quant_op)) {
if (op_reuse_src1(node, prev_op)) {
flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
}
prev_op = node;
// ask for early notification for the last Op
if (i == last) {
flags |= HTP_OPFLAGS_EARLY_WAKEUP;
@@ -2520,7 +2508,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
} else {
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
}
prev_quant_op = node;
break;
case GGML_OP_MUL_MAT_ID:
if (ggml_is_quantized(node->src[0]->type)) {
@@ -2528,7 +2515,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
} else {
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
}
prev_quant_op = node;
break;
case GGML_OP_MUL:
case GGML_OP_ADD:
@@ -2670,7 +2656,7 @@ static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<no
}
// that many nodes forward to search for stackable nodes that can reuse VTCM
constexpr int N_FORWARD = 8;
constexpr int N_FORWARD = 16;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) {
@@ -3056,10 +3042,12 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
}
}
#if defined(__ANDROID__)
if (opt_arch < 75) {
opt_ndev = 1;
GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
}
#endif
GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
@@ -3156,6 +3144,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_arch = strtoul(str_arch, NULL, 0);
}
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
reg->context = new ggml_hexagon_registry(reg);
HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
@@ -3180,6 +3170,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
auto nErr = htpdrv_init();
if (nErr != AEE_SUCCESS) {
return NULL;
}
ggml_hexagon_init(&reg);
}
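The graph-compute hunks above appear to replace `prev_quant_op` (updated only in the matmul cases) with `prev_op` (updated for every executed node), reading each pair as the old line followed by the new one since the +/- markers are not visible here. A minimal sketch of that tracking, assuming simplified stand-in types: `Node`, `OPFLAG_SKIP_QUANTIZE`, `reuse_src1` and `plan_flags` are hypothetical names, not ggml or HTP identifiers.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a graph node; not the real ggml_tensor.
struct Node {
    const void * src0;           // weights
    const void * src1;           // activations
    bool         src0_quantized; // stands in for ggml_is_quantized(src[0]->type)
};

constexpr uint32_t OPFLAG_SKIP_QUANTIZE = 1u << 0; // hypothetical flag value

// Same shape as the single-check variant of op_reuse_src1 in the hunk:
// the previous op must exist, share src1, and have quantized src0.
static bool reuse_src1(const Node & op1, const Node * op0) {
    return op0 && op0->src1 == op1.src1 && op0->src0_quantized;
}

// Compute per-node flags by comparing each node with the node executed
// immediately before it.
std::vector<uint32_t> plan_flags(const std::vector<Node> & nodes) {
    std::vector<uint32_t> flags;
    flags.reserve(nodes.size());
    const Node * prev_op = nullptr;
    for (const Node & n : nodes) {
        uint32_t f = 0;
        if (reuse_src1(n, prev_op)) {
            f |= OPFLAG_SKIP_QUANTIZE;
        }
        prev_op = &n; // updated for every node, not only for matmuls
        flags.push_back(f);
    }
    return flags;
}
```

With this shape, any intervening op resets the comparison point, so the quantize step is skipped only when the immediately preceding op used the same `src1` with a quantized `src0`.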
+418
@@ -0,0 +1,418 @@
// sample drv interface
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wsign-compare"
#include <filesystem>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include "ggml-impl.h"
#include "htp-drv.h"
#include "libdl.h"
#include <domain.h>
//
// Driver API types
//
typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
typedef void (*rpcmem_free_pfn_t)(void * po);
typedef int (*rpcmem_to_fd_pfn_t)(void * po);
typedef AEEResult (*dspqueue_create_pfn_t)(int domain,
uint32_t flags,
uint32_t req_queue_size,
uint32_t resp_queue_size,
dspqueue_callback_t packet_callback,
dspqueue_callback_t error_callback,
void * callback_context,
dspqueue_t * queue);
typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
uint32_t num_buffers,
struct dspqueue_buffer *buffers,
uint32_t message_length,
const uint8_t *message,
uint32_t timeout_us);
typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
uint32_t max_buffers, uint32_t *num_buffers,
struct dspqueue_buffer *buffers,
uint32_t max_message_length,
uint32_t *message_length, uint8_t *message,
uint32_t timeout_us);
typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
typedef int (*remote_handle64_close_pfn_t)(remote_handle64 h);
typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
//
// Driver API pfns
//
rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr;
rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
rpcmem_free_pfn_t rpcmem_free_pfn = nullptr;
rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr;
fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr;
fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
dspqueue_close_pfn_t dspqueue_close_pfn = nullptr;
dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
dspqueue_write_pfn_t dspqueue_write_pfn = nullptr;
dspqueue_read_pfn_t dspqueue_read_pfn = nullptr;
remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr;
remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr;
remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr;
remote_handle_control_pfn_t remote_handle_control_pfn = nullptr;
remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
remote_session_control_pfn_t remote_session_control_pfn = nullptr;
//
// Driver API
//
void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
return rpcmem_alloc_pfn(heapid, flags, size);
}
void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
if (rpcmem_alloc2_pfn) {
return rpcmem_alloc2_pfn(heapid, flags, size);
} else {
GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
return rpcmem_alloc_pfn(heapid, flags, size);
}
}
void rpcmem_free(void * po) {
return rpcmem_free_pfn(po);
}
int rpcmem_to_fd(void * po) {
return rpcmem_to_fd_pfn(po);
}
HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
}
HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
return fastrpc_munmap_pfn(domain, fd, addr, length);
}
AEEResult dspqueue_create(int domain,
uint32_t flags,
uint32_t req_queue_size,
uint32_t resp_queue_size,
dspqueue_callback_t packet_callback,
dspqueue_callback_t error_callback,
void * callback_context,
dspqueue_t * queue) {
return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
callback_context, queue);
}
AEEResult dspqueue_close(dspqueue_t queue) {
return dspqueue_close_pfn(queue);
}
AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
return dspqueue_export_pfn(queue, queue_id);
}
AEEResult dspqueue_write(dspqueue_t queue,
uint32_t flags,
uint32_t num_buffers,
struct dspqueue_buffer * buffers,
uint32_t message_length,
const uint8_t * message,
uint32_t timeout_us) {
return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
}
AEEResult dspqueue_read(dspqueue_t queue,
uint32_t * flags,
uint32_t max_buffers,
uint32_t * num_buffers,
struct dspqueue_buffer * buffers,
uint32_t max_message_length,
uint32_t * message_length,
uint8_t * message,
uint32_t timeout_us) {
return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
message, timeout_us);
}
HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
return remote_handle64_open_pfn(name, ph);
}
HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
return remote_handle64_invoke_pfn(h, dwScalars, pra);
}
HTPDRV_API int remote_handle64_close(remote_handle64 h) {
return remote_handle64_close_pfn(h);
}
HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
return remote_handle_control_pfn(req, data, datalen);
}
HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
return remote_handle64_control_pfn(h, req, data, datalen);
}
HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
return remote_session_control_pfn(req, data, datalen);
}
#ifdef _WIN32
static std::string wstr_to_str(std::wstring_view wstr) {
std::string result;
if (wstr.empty()) {
return result;
}
auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
wstr.data(), (int) wstr.size(),
nullptr, 0, nullptr, nullptr);
if (bytes_needed == 0) {
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
throw std::runtime_error("Invalid wstring input");
}
result.resize(bytes_needed, '\0');
int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
wstr.data(), (int) wstr.size(),
result.data(), bytes_needed,
nullptr, nullptr);
if (bytes_written == 0) {
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
throw std::runtime_error("Wstring conversion failed");
}
return result;
}
static std::string get_driver_path() {
std::wstring serviceName = L"qcnspmcdm";
std::string result;
// Get a handle to the SCM database.
SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);
if (nullptr == schSCManager) {
GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
return result;
}
// Get a handle to the service.
SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database
serviceName.c_str(), // name of service
SERVICE_QUERY_CONFIG); // need query config access
if (nullptr == schService) {
GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
CloseServiceHandle(schSCManager);
return result;
}
// Store the size of buffer used as an output.
DWORD bufferSize;
if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&
(GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {
GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return result;
}
// Get the configuration of the service.
LPQUERY_SERVICE_CONFIGW serviceConfig =
static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));
if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {
GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
LocalFree(serviceConfig);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return result;
}
// Read the driver file path and get its parent directory
std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);
driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\"));
// Clean up resources
LocalFree(serviceConfig);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
// The driver path is expected to start with a "\SystemRoot" placeholder rather than a real directory, e.g.:
// \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
// "\SystemRoot" must be replaced with the actual system root (e.g. C:\Windows)
const std::wstring systemRootPlaceholder = L"\\SystemRoot";
if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
return result;
}
// Replace \SystemRoot with an absolute path from system ENV windir
const std::wstring systemRootEnv = L"windir";
// Query the number of wide characters this variable requires
DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);
if (numWords == 0) {
GGML_LOG_ERROR("ggml-hex: Failed to get windir environment variable\n");
return result;
}
// Query the actual system root name from environment variable
std::vector<wchar_t> systemRoot(numWords + 1);
numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);
if (numWords == 0) {
GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
return result;
}
driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));
return wstr_to_str(driverPath);
}
#endif
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
int htpdrv_init() {
static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
static bool initialized = false;
#ifdef _WIN32
std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
#else
std::string drv_path = "libcdsprpc.so";
#endif
if (initialized) {
GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
return AEE_SUCCESS;
}
GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
fs::path path{ drv_path.c_str() };
dl_handle_ptr handle { dl_load_library(path) };
if (!handle) {
GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
return AEE_EUNABLETOLOAD;
}
#define dlsym(drv, type, pfn, symbol, ignore) \
do { \
pfn = (type) dl_get_sym(drv, #symbol); \
if (!ignore && nullptr == pfn) { \
GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \
return AEE_EUNABLETOLOAD; \
} \
} while (0)
dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
lib_cdsp_rpc_handle = std::move(handle);
initialized = true;
return AEE_SUCCESS;
}
domain * get_domain(int domain_id) {
int i = 0;
int size = sizeof(supported_domains) / sizeof(domain);
for (i = 0; i < size; i++) {
if (supported_domains[i].id == domain_id) {
return &supported_domains[i];
}
}
return NULL;
}
int get_hex_arch_ver(int domain, int * arch) {
if (!remote_handle_control_pfn) {
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
struct remote_dsp_capability arch_ver;
arch_ver.domain = (uint32_t) domain;
arch_ver.attribute_ID = ARCH_VER;
arch_ver.capability = (uint32_t) 0;
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
if (err != AEE_SUCCESS) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
return err;
}
switch (arch_ver.capability & 0xff) {
case 0x68:
*arch = 68;
return 0;
case 0x69:
*arch = 69;
return 0;
case 0x73:
*arch = 73;
return 0;
case 0x75:
*arch = 75;
return 0;
case 0x79:
*arch = 79;
return 0;
case 0x81:
*arch = 81;
return 0;
}
return -1;
}
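The wrappers in this file forward each driver call through a function pointer resolved at load time, with `rpcmem_alloc2` treated as optional and `rpcmem_alloc` as the fallback. A minimal sketch of that dispatch, assuming a hypothetical `DriverTable` struct and `malloc`-backed stand-ins (`fake_alloc`, `fake_alloc2`) in place of the real FastRPC entry points:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

// Pointer types mirroring the two allocator signatures shown in the diff.
using alloc_fn  = void * (*)(int heapid, uint32_t flags, int size);
using alloc2_fn = void * (*)(int heapid, uint32_t flags, size_t size);

// Hypothetical table; the diff keeps these as separate globals filled by htpdrv_init().
struct DriverTable {
    alloc_fn  alloc  = nullptr; // required symbol
    alloc2_fn alloc2 = nullptr; // optional symbol, may stay null
};

// Use the preferred entry point when it was resolved, otherwise fall back.
void * table_alloc(const DriverTable & t, int heapid, uint32_t flags, size_t size) {
    if (t.alloc2) {
        return t.alloc2(heapid, flags, size);
    }
    if (t.alloc) {
        // The older entry point takes an int size, so the value is narrowed here.
        return t.alloc(heapid, flags, static_cast<int>(size));
    }
    return nullptr; // nothing loaded
}

// Stand-in allocators backed by malloc, used only to exercise the dispatch.
static int g_calls_alloc  = 0;
static int g_calls_alloc2 = 0;
static void * fake_alloc(int, uint32_t, int size)     { ++g_calls_alloc;  return std::malloc(static_cast<size_t>(size)); }
static void * fake_alloc2(int, uint32_t, size_t size) { ++g_calls_alloc2; return std::malloc(size); }
```

The sketch makes the `size_t` to `int` narrowing on the fallback path explicit; requests larger than `INT_MAX` would need separate handling, which neither the sketch nor (as far as this diff shows) the wrapper attempts.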
+121
@@ -0,0 +1,121 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#ifdef _WIN32
# pragma clang diagnostic ignored "-Wignored-attributes"
#endif
#include <AEEStdErr.h>
#include <rpcmem.h>
#include <remote.h>
#include <dspqueue.h>
#if defined(_WIN32) && !defined(__MINGW32__)
# ifdef GGML_BACKEND_BUILD
# define HTPDRV_API __declspec(dllexport) extern
# else
# define HTPDRV_API __declspec(dllimport) extern
# endif
#else
# define HTPDRV_API __attribute__ ((visibility ("default"))) extern
#endif
/* Offset to differentiate HLOS and Hexagon error codes.
Stores the value of AEE_EOFFSET for Hexagon. */
#ifndef DSP_OFFSET
# define DSP_OFFSET 0x80000400
#endif
/* Errno for connection reset by peer. */
#ifndef ECONNRESET
# ifdef __hexagon__
# define ECONNRESET 104
# endif
#endif
/* Abstraction of different OS specific sleep APIs.
SLEEP accepts input in seconds. */
#ifndef SLEEP
# ifdef __hexagon__
# define SLEEP(x) \
{ /* Do nothing for simulator. */ \
}
# else
# ifdef _WIN32
# define SLEEP(x) Sleep(1000 * (x)) /* Sleep accepts input in milliseconds. */
# else
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
# endif
# endif
#endif
/* Include windows specific header files. */
#ifdef _WIN32
# include <windows.h>
# include <sysinfoapi.h>
# define _CRT_SECURE_NO_WARNINGS 1
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
#endif
/* Includes and defines for all HLOS except windows */
#if !defined(__hexagon__) && !defined(_WIN32)
# include "unistd.h"
# include <sys/time.h>
#endif
/* Includes and defines for Hexagon and all HLOS except Windows. */
#if !defined(_WIN32)
/* Weak reference to remote symbol for compilation. */
# pragma weak remote_session_control
# pragma weak remote_handle_control
# pragma weak remote_handle64_control
# pragma weak fastrpc_mmap
# pragma weak fastrpc_munmap
# pragma weak rpcmem_alloc2
#endif
#if !defined(_WIN32)
# pragma weak remote_system_request
#endif
#ifdef _WIN32
# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE
#else
# define DSPQUEUE_TIMEOUT 1000000
#endif
/**
* htpdrv_init API: driver interface entry point
*
* @return Return AEE error codes as defined in Hexagon SDK.
*/
HTPDRV_API int htpdrv_init(void);
/**
* get_domain API: get domain struct from domain value.
*
* @param[in] domain_id value of a domain
* @return Returns domain struct of the domain if it is supported or else
* returns NULL.
*
*/
HTPDRV_API domain * get_domain(int domain_id);
/**
* get_hex_arch_ver API: query the Hexagon processor architecture version information
*
* @param[in] domain value of a domain
* @param[out] arch Arch version (73, 75, ...)
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
HTPDRV_API int get_hex_arch_ver(int domain, int * arch);
#ifdef __cplusplus
}
#endif
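`get_hex_arch_ver`, declared in this header, maps the low byte of the reported capability to a decimal version in htp-drv.cpp (0x68 to 68, 0x73 to 73, and so on). A minimal sketch of the same mapping written as digit arithmetic rather than a switch over literals; `decode_hex_arch` is a hypothetical helper and is not part of the diff:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Decode the low byte of a capability word as two decimal digits (0x73 -> 73).
// Returns std::nullopt for anything outside the versions listed in the diff.
std::optional<int> decode_hex_arch(uint32_t capability) {
    const uint32_t byte = capability & 0xff;
    const int hi = static_cast<int>(byte >> 4);
    const int lo = static_cast<int>(byte & 0x0f);
    if (hi > 9 || lo > 9) {
        return std::nullopt; // nibble is not a decimal digit
    }
    const int arch = hi * 10 + lo;
    switch (arch) {
        case 68: case 69: case 73: case 75: case 79: case 81:
            return arch;
        default:
            return std::nullopt; // well-formed but not a known version
    }
}
```

Returning `std::optional` keeps the "unknown version" case separate from a valid result, where the original signals it with a `-1` return code alongside the out-parameter.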
-454
@@ -1,454 +0,0 @@
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wsign-compare"
#define GGML_COMMON_IMPL_C
#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include "ggml-hexagon.h"
#include "ggml-impl.h"
#include "htp-utils.h"
#include <domain.h>
#include <remote.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
domain * get_domain(int domain_id) {
int i = 0;
int size = sizeof(supported_domains) / sizeof(domain);
for (i = 0; i < size; i++) {
if (supported_domains[i].id == domain_id) {
return &supported_domains[i];
}
}
return NULL;
}
bool is_valid_domain_id(int domain_id, int compute_only) {
int i = 0;
int size = sizeof(supported_domains) / sizeof(domain);
if (compute_only) {
return is_CDSP(domain_id);
}
for (i = 0; i < size; i++) {
if (supported_domains[i].id == domain_id) {
return true;
}
}
return false;
}
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
int nErr = AEE_SUCCESS;
int ss_info = 0;
if (domain_type != NULL) {
if (strcmp(domain_type, "LPASS") == 0) {
ss_info = FASTRPC_LPASS;
} else if (strcmp(domain_type, "HPASS") == 0) {
ss_info = FASTRPC_HPASS;
} else {
ss_info = FASTRPC_NSP;
}
}
system_req_payload req = { 0 };
req.id = FASTRPC_GET_DOMAINS;
req.sys.domains = NULL;
fastrpc_domain * domain = NULL;
if (ss_info != 0) {
req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
} else {
req.sys.flags = 0;
}
#ifdef _WIN32
nErr = AEE_EUNSUPPORTED;
goto bail;
#endif
if (remote_system_request) {
nErr = remote_system_request(&req);
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
goto bail;
}
// Allocate memory for domain-info array
req.sys.max_domains = req.sys.num_domains;
if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
nErr = AEE_ENOMEMORY;
GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
goto bail;
}
nErr = remote_system_request(&req);
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
goto bail;
}
for (int i = 0; i < req.sys.num_domains; i++) {
// Verify that only requested type domains were returned
domain = &req.sys.domains[i];
if (domain->type != ss_info && domain_type != NULL) {
nErr = -1;
GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
goto bail;
}
}
*domains_info = req.sys.domains;
*num_domains = req.sys.num_domains;
} else {
nErr = AEE_EUNSUPPORTED;
goto bail;
}
bail:
if (nErr && !req.sys.domains) {
free(req.sys.domains);
}
return nErr;
}
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
int err = 0;
remote_rpc_effective_domain_id_t sess = { 0 };
sess.domain_name = domain_name;
sess.domain_name_len = strlen(domain_name);
sess.session_id = session_id;
err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
if (err) {
GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
session_id);
return err;
}
*effec_domain_id = sess.effective_domain_id;
return err;
}
int get_dsp_support(int * domain) {
int nErr = AEE_SUCCESS;
*domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID
if (remote_handle_control) {
struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
goto bail;
}
if (dsp_capability_domain.capability == 0) {
dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support.
dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
dsp_capability_domain.capability = 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
sizeof(struct remote_dsp_capability));
if (dsp_capability_domain.capability) {
*domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
}
}
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
int nErr = AEE_SUCCESS;
*capability = 0;
if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
} else {
nErr = AEE_EBADPARM;
GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
goto bail;
}
if (remote_handle_control) {
if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for VTCM information
* Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
*/
struct remote_dsp_capability dsp_capability_vtcm_dsp;
dsp_capability_vtcm_dsp.domain = (uint32_t) domain;
dsp_capability_vtcm_dsp.attribute_ID = attr;
dsp_capability_vtcm_dsp.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (nErr == AEE_SUCCESS) {
*capability = dsp_capability_vtcm_dsp.capability;
} else {
GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("Unsupported domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}
bool is_unsignedpd_supported(int domain_id) {
int nErr = AEE_SUCCESS;
if (remote_handle_control) {
struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
return false;
}
if (nErr) {
GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
return false;
}
if (dsp_capability_domain.capability == 1) {
return true;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
return false;
}
return false;
}
bool get_unsignedpd_support(void) {
return is_unsignedpd_supported(CDSP_DOMAIN_ID);
}
bool is_async_fastrpc_supported(int domain) {
int nErr = AEE_SUCCESS;
if (remote_handle_control) {
if (domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for ASYNC_FASTRPC_SUPPORT information
* Async fastrpc is supported only on CDSP
*/
struct remote_dsp_capability dsp_capability_async_support;
dsp_capability_async_support.domain = (uint32_t) domain;
dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
dsp_capability_async_support.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (dsp_capability_async_support.capability == 1) {
return true;
}
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return false;
}
bool is_status_notification_supported(int domain) {
int nErr = AEE_SUCCESS;
if (remote_handle_control) {
/*
* Query the DSP for STATUS_NOTIFICATION_SUPPORT information
* DSP User PD status notification Support
*/
struct remote_dsp_capability dsp_capability_status_notification_support;
dsp_capability_status_notification_support.domain = (uint32_t) domain;
dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
dsp_capability_status_notification_support.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (dsp_capability_status_notification_support.capability == 1) {
return true;
}
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return false;
}
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
int nErr = AEE_SUCCESS;
*capability = 0;
if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
nErr = AEE_EBADPARM;
GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
goto bail;
}
if (remote_handle_control) {
if (domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for HMX SUPPORT information
* HMX is supported on CDSP only
*/
struct remote_dsp_capability dsp_capability_hmx_dsp;
dsp_capability_hmx_dsp.domain = (uint32_t) domain;
dsp_capability_hmx_dsp.attribute_ID = attr;
dsp_capability_hmx_dsp.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (nErr == AEE_SUCCESS) {
*capability = dsp_capability_hmx_dsp.capability;
} else {
GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}
int get_hex_arch_ver(int domain, int * arch) {
if (!remote_handle_control) {
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
struct remote_dsp_capability arch_ver;
arch_ver.domain = (uint32_t) domain;
arch_ver.attribute_ID = ARCH_VER;
arch_ver.capability = (uint32_t) 0;
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
if (err != AEE_SUCCESS) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
return err;
}
switch (arch_ver.capability & 0xff) {
case 0x68:
*arch = 68;
return 0;
case 0x69:
*arch = 69;
return 0;
case 0x73:
*arch = 73;
return 0;
case 0x75:
*arch = 75;
return 0;
case 0x79:
*arch = 79;
return 0;
case 0x81:
*arch = 81;
return 0;
}
return -1;
}
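The switch in `get_hex_arch_ver` maps a capability byte whose hex digits read as the decimal version (0x68 becomes 68, 0x73 becomes 73). A minimal sketch of the same mapping done arithmetically is shown below. `decode_hex_arch` is a hypothetical helper, not part of the source, and unlike the switch it accepts any byte made of two decimal digits rather than only the listed versions.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (assumption, not in the original file): decode the low
// byte of an ARCH_VER capability value by treating each hex nibble as one
// decimal digit. Returns -1 when a nibble is not a valid decimal digit.
static int decode_hex_arch(uint32_t capability) {
    const uint32_t low_byte = capability & 0xff;  // same mask as the switch
    const uint32_t tens     = low_byte >> 4;      // high nibble
    const uint32_t ones     = low_byte & 0xf;     // low nibble
    if (tens > 9 || ones > 9) {
        return -1;
    }
    return static_cast<int>(tens * 10 + ones);
}
```

The explicit switch in the source has one advantage over this sketch: it rejects versions the backend has no library for, so the arithmetic form would need a separate allow-list to match that behavior.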
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
int nErr = AEE_SUCCESS;
*capability = 0;
if (remote_handle_control) {
if (domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for HVX SUPPORT information
* HVX is supported on CDSP only
*/
struct remote_dsp_capability dsp_capability_hvx_dsp;
dsp_capability_hvx_dsp.domain = (uint32_t) domain;
dsp_capability_hvx_dsp.attribute_ID = attr;
dsp_capability_hvx_dsp.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (nErr == AEE_SUCCESS) {
*capability = dsp_capability_hvx_dsp.capability;
} else {
GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}
-221
@@ -1,221 +0,0 @@
#ifndef HTP_UTILS_H
#define HTP_UTILS_H
#ifdef __cplusplus
extern "C" {
#endif
#include <AEEStdErr.h>
#include <inttypes.h>
#include <remote.h>
#include <rpcmem.h>
#include <stdbool.h>
/* Offset to differentiate HLOS and Hexagon error codes.
Stores the value of AEE_EOFFSET for Hexagon. */
#ifndef DSP_OFFSET
# define DSP_OFFSET 0x80000400
#endif
/* Errno for connection reset by peer. */
#ifndef ECONNRESET
# ifdef __hexagon__
# define ECONNRESET 104
# endif
#endif
/* Abstraction of different OS specific sleep APIs.
SLEEP accepts input in seconds. */
#ifndef SLEEP
# ifdef __hexagon__
# define SLEEP(x) \
{ /* Do nothing for simulator. */ \
}
# else
# ifdef _WINDOWS
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
# else
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
# endif
# endif
#endif
/* Include windows specific header files. */
#ifdef _WINDOWS
# include <sysinfoapi.h>
# include <windows.h>
# define _CRT_SECURE_NO_WARNINGS 1
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
/* Including this file for custom implementation of getopt function. */
# include "getopt_custom.h"
#endif
/* Includes and defines for all HLOS except windows */
#if !defined(__hexagon__) && !defined(_WINDOWS)
# include "unistd.h"
# include <sys/time.h>
#endif
/* Includes and defines for Hexagon and all HLOS except Windows. */
#if !defined(_WINDOWS)
/* Weak reference to remote symbol for compilation. */
# pragma weak remote_session_control
# pragma weak remote_handle_control
# pragma weak remote_handle64_control
# pragma weak fastrpc_mmap
# pragma weak fastrpc_munmap
# pragma weak rpcmem_alloc2
#endif
#if !defined(_WINDOWS)
# pragma weak remote_system_request
#endif
/**
* Wrapper for FastRPC Capability API: query DSP support.
*
* @param[out] domain pointer to supported domain.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*/
int get_dsp_support(int * domain);
/**
* Wrapper for FastRPC Capability API: query VTCM information.
*
* @param[in] domain value of the domain to be queried.
* @param[out] capability capability value of the attribute queried.
* @param[in] attr value of the attribute to be queried.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*/
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
/**
* Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
*
* @return true if unsigned pd is supported.
* false if unsigned pd is not supported or the capability query failed.
*/
bool get_unsignedpd_support(void);
/**
* Wrapper for FastRPC Capability API: query unsigned pd support.
*
* @param[in] domain_id value of the domain to be queried.
* @return true if unsigned pd is supported.
* false if unsigned pd is not supported or the capability query failed.
*/
bool is_unsignedpd_supported(int domain_id);
/**
* is_valid_domain_id API: query whether a domain id is valid.
*
* @param[in] domain_id value of the domain to be queried.
* @param[in] compute_only when enabled, the domain value is compared only with the CDSP domains supported by the target.
* @return true if value of domain is valid.
* false if value of domain is not valid.
*/
bool is_valid_domain_id(int domain_id, int compute_only);
/**
* get_domain API: get domain struct from domain value.
*
* @param[in] domain value of a domain
* @return Returns domain struct of the domain if it is supported or else
* returns NULL.
*
*/
domain * get_domain(int domain_id);
/**
* get_domains_info API: get information for all the domains available on the device
*
* @param[in] domain_type pointer to domain type
* @param[in] num_domains pointer to number of domains
* @param[in] domains_info pointer to save discovered domains information.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
* It is the user's responsibility to free the memory used to store the domains info, whose address is stored in domains_info, before closing the application.
*
*/
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
/**
* get_effective_domain_id API: get effective domain id for given session id
*
* @param[in] domain_name pointer to domain name
* @param[in] session_id
* @param[in] effec_domain_id pointer to save obtained effective domain id.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
/**
* is_async_fastrpc_supported API: query whether a domain id supports async fastrpc
*
* @param[in] domain_id value of a domain
* @return Returns true or false stating support of Async FastRPC
*
*/
bool is_async_fastrpc_supported(int domain_id);
/**
* is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
*
* @param[in] domain_id value of a domain
* @return Returns true or false stating status notification support information
*
*/
bool is_status_notification_supported(int domain_id);
/**
* get_hmx_support_info API: query the DSP for HMX SUPPORT information
*
* @param[in] domain_id value of a domain
* @param[out] capability capability value of the attribute queried.
* @param[in] attr value of the attribute to be queried.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
/**
* get_hex_arch_ver API: query the Hexagon processor architecture version information
*
* @param[in] domain_id value of a domain
* @param[out] arch Arch version (73, 75, ...)
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_hex_arch_ver(int domain, int * arch);
/**
* get_hvx_support_info API: query the DSP for HVX SUPPORT information
*
* @param[in] domain_id value of a domain
* @param[out] capability capability value of the attribute queried.
* @param[in] attr value of the attribute to be queried.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
#ifdef __cplusplus
}
#endif
#endif // HTP_UTILS_H
+79
@@ -0,0 +1,79 @@
#pragma once
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
static inline const char * dl_error() {
return "";
}
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
static inline dl_handle * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
static inline void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static inline const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
+38
@@ -0,0 +1,38 @@
[Version]
Signature = "$WINDOWS NT$"
Class = ComputeAccelerator
ClassGuid = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C}
Provider = %GGML%
DriverVer = 01/01/2026,1.0.0.0
CatalogFile = libggml-htp.cat
PnpLockDown = 1
[DestinationDirs]
Drivers_Dir = 6
[SourceDisksNames]
1 = %DiskId%
[SourceDisksFiles]
libggml-htp-v68.so = 1
libggml-htp-v69.so = 1
libggml-htp-v73.so = 1
libggml-htp-v75.so = 1
libggml-htp-v81.so = 1
[ControlFlags]
ExcludeFromSelect = *
[DefaultInstall.NTarm64]
CopyFiles=Drivers_Dir
[Drivers_Dir]
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE
[Strings]
GGML = 'GGML'
DiskId = 'GGML HTP library'
+2
@@ -62,6 +62,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
-1
@@ -15,7 +15,6 @@
#include <sycl/sycl.hpp>
#include <sycl/half_type.hpp>
#include <syclcompat/math.hpp>
#include <map>
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
+20
@@ -123,6 +123,15 @@ static __dpct_inline__ T op_log(T x) {
return sycl::log(x);
}
template<typename T>
static __dpct_inline__ T op_softplus(T x) {
const float xf = (float) x;
const float ax = sycl::fabs(xf);
const float m = sycl::fmax(xf, 0.0f);
const float y = m + sycl::log1p(sycl::exp(-ax));
return (T) y;
}
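The `op_softplus` body above uses the rewritten form max(x, 0) + log1p(exp(-|x|)) instead of log(1 + exp(x)). A host-side sketch of why that matters is given below, using only `<cmath>`. The function names `softplus_stable` and `softplus_naive` are illustrative assumptions and do not appear in the SYCL backend.

```cpp
#include <cassert>
#include <cmath>

// Illustrative host-side version of the formula used by op_softplus:
// softplus(x) = max(x, 0) + log1p(exp(-|x|)).
// exp() only ever sees a non-positive argument, so it cannot overflow.
static float softplus_stable(float x) {
    return std::fmax(x, 0.0f) + std::log1p(std::exp(-std::fabs(x)));
}

// Textbook form, shown only for comparison: exp(x) overflows to +inf in
// single precision once x is large, and the result becomes +inf as well.
static float softplus_naive(float x) {
    return std::log(1.0f + std::exp(x));
}
```

For large positive inputs the stable form returns x itself, while the naive form overflows. For large negative inputs both tend to zero, but `log1p` keeps precision where `log(1 + tiny)` would round to zero.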
template<typename T>
static __dpct_inline__ T op_neg(T x) {
return -x;
@@ -695,6 +704,12 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
});
}
static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_softplus(x);
});
}
static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_neg(x);
@@ -1101,6 +1116,11 @@ void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_log(ctx, dst);
}
void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_softplus(ctx, dst);
}
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_neg(ctx, dst);
+2
@@ -61,6 +61,8 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+75 -4
@@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
}
static void tri_f32_sycl(
const float * src,
float * dst,
const int64_t ne0,
const int64_t ne1,
const int64_t ne2,
const int64_t ne3,
const ggml_tri_type ttype,
dpct::queue_ptr main_stream
) {
const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
const int64_t idx = (int64_t) tid[0];
const int64_t i0 = idx % ne0;
const int64_t t1 = idx / ne0;
const int64_t i1 = t1 % ne1;
bool keep = false;
switch (ttype) {
case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break;
case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break;
case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
default: keep = false; break;
}
dst[idx] = keep ? src[idx] : 0.0f;
});
}
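The kernel above decodes each flat index into a column `i0` (fastest dimension) and a row `i1`, then keeps or zeroes the element depending on the triangle type. A CPU reference that mirrors this indexing may help when checking results. In the sketch below, `TriType` is a hypothetical stand-in for `ggml_tri_type`, and `tri_reference` is an assumed name, not a function from the backend.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ggml_tri_type (assumption for this sketch).
enum class TriType { Lower, LowerDiag, Upper, UpperDiag };

// CPU reference for a contiguous F32 tensor of shape [ne0, ne1, ...]:
// i0 is the column (fastest-varying) index, i1 is the row index. Higher
// dimensions are treated as a batch of independent matrices.
static std::vector<float> tri_reference(const std::vector<float> & src,
                                        int64_t ne0, int64_t ne1, TriType type) {
    assert(ne0 > 0 && ne1 > 0);
    std::vector<float> dst(src.size(), 0.0f);
    for (std::size_t n = 0; n < src.size(); ++n) {
        const int64_t idx = static_cast<int64_t>(n);
        const int64_t i0  = idx % ne0;
        const int64_t i1  = (idx / ne0) % ne1;
        bool keep = false;
        switch (type) {
            case TriType::Lower:     keep = (i0 <  i1); break;
            case TriType::LowerDiag: keep = (i0 <= i1); break;
            case TriType::Upper:     keep = (i0 >  i1); break;
            case TriType::UpperDiag: keep = (i0 >= i1); break;
        }
        if (keep) {
            dst[n] = src[n];
        }
    }
    return dst;
}
```

Because the row index is taken modulo `ne1`, the same loop covers every matrix in the `ne2 * ne3` batch without extra index arithmetic, which is also why the kernel never needs `i2` or `i3`.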
static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(src0->data);
float * dst_dd = static_cast<float *>(dst->data);
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
const int64_t ne0 = src0->ne[0];
const int64_t ne1 = src0->ne[1];
const int64_t ne2 = src0->ne[2];
const int64_t ne3 = src0->ne[3];
tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
}
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -3786,6 +3845,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_UNARY_OP_EXP:
ggml_sycl_exp(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_sycl_softplus(ctx, dst);
break;
case GGML_UNARY_OP_SGN:
ggml_sycl_sgn(ctx, dst);
break;
@@ -3912,6 +3974,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_TRANSPOSE:
GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
break;
case GGML_OP_TRI:
ggml_sycl_op_tri(ctx, dst);
break;
case GGML_OP_DIAG_MASK_INF:
ggml_sycl_diag_mask_inf(ctx, dst);
break;
@@ -4404,6 +4469,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_ELU:
return true;
case GGML_UNARY_OP_FLOOR:
@@ -4606,18 +4672,23 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
return true;
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
return true;
case GGML_OP_RMS_NORM_BACK:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_TRI:
{
const ggml_tensor * src0 = op->src[0];
return src0 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0);
}
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
-3
@@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
const float eps, queue_ptr stream, int device) {
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
@@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst,
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+43 -21
@@ -3162,17 +3162,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
// For scalar, use 128 (arbitrary)
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
const uint32_t D = (hsk|hsv);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
uint32_t wg_size;
switch (path) {
case FA_COOPMAT2:
wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
break;
case FA_COOPMAT1:
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
break;
default:
wg_size = scalar_flash_attention_workgroup_size;
break;
}
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
// Nvidia prefers shared memory use to load large tiles of K
// AMD prefers loading K directly from global memory
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
};
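The `D_split` computation in the lambda above relies on the bit trick `D ^ (D & (D-1))`, which isolates the lowest set bit of `D = hsk | hsv`. The standalone sketch below reproduces only that arithmetic so the clamping can be checked by hand. `lowest_set_bit` and `compute_d_split` are assumed helper names and are not part of the Vulkan backend.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// d & (d - 1) clears the lowest set bit; XOR with d leaves only that bit.
static uint32_t lowest_set_bit(uint32_t d) {
    return d ^ (d & (d - 1));
}

// Mirrors the expression in the lambda: D_split is capped by the subgroup
// size, by 8, and by (lowest set bit of hsk|hsv) / 4.
static uint32_t compute_d_split(uint32_t subgroup_size, uint32_t hsk, uint32_t hsv) {
    const uint32_t d = hsk | hsv;
    return std::min(std::min(subgroup_size, 8u), lowest_set_bit(d) / 4);
}
```

As an example, head sizes of 40 give a lowest set bit of 8 and therefore a `D_split` of 2, while head sizes of 128 hit the cap of 8 on a subgroup of 32 threads.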
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3187,15 +3201,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} \
} \
@@ -8344,41 +8358,49 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];
const uint32_t MatBr = 16, MatBc = 16;
const uint32_t row_split = Bc / MatBc;
const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
const uint32_t qstride = hsk_pad / 4 + 2;
const uint32_t Qf = Br * qstride * f16vec4;
const uint32_t psh_stride = Br / 4 + 2;
const uint32_t Psh = Bc * psh_stride * f16vec4;
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
const uint32_t kshstride = hsk_pad / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
const uint32_t vsh_stride = MatBc / 4 * row_split;
const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
const uint32_t slope = Br * sizeof(float);
const uint32_t slope = Br * acctype;
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
@@ -8442,7 +8464,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR;
@@ -8,6 +8,8 @@ layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -74,6 +76,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 1
#endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
@@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
@@ -14,12 +15,13 @@
#include "types.glsl"
#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
const uint32_t row_split = 4;
const uint32_t row_split = Bc / MatBc;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -40,24 +42,24 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
const uint psh_stride = Br / 4 + 2;
shared f16vec4 Psh[Bc * psh_stride];
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
const uint vsh_stride = v_cols;
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
shared float slope[Br];
shared ACC_TYPE slope[Br];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -69,9 +71,9 @@ void main() {
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
#define tile_row(r) (row_tid * rows_per_thread + (r))
@@ -102,9 +104,9 @@ void main() {
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
Of[r][d] = ACC_TYPEV4(0.0);
}
}
@@ -125,13 +127,11 @@ void main() {
uint r = tid;
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
barrier();
} else {
if (tid < Br) {
uint r = tid;
slope[r] = 1.0;
slope[r] = ACC_TYPE(1.0);
}
barrier();
}
#if BLOCK_SIZE > 1
@@ -149,19 +149,45 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float mask_cache[Bc * Br / WorkGroupSize];
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
if ((!KV_bounds_check || j * Bc + c < KV)) {
f16vec4 m;
if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
} else if (i * Br + r * 4 + 2 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
0.0);
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
} else if (i * Br + r * 4 + 1 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
0.0,
0.0);
max_mask = max(max(max_mask, float(m[0])), float(m[1]));
} else if (i * Br + r * 4 < p.nem1) {
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
0.0,
0.0,
0.0);
max_mask = max(max_mask, float(m[0]));
} else {
m = f16vec4(0.0);
}
mask_cache[idx / WorkGroupSize] = m;
max_mask = max(max_mask, m);
}
}
}
@@ -180,26 +206,28 @@ void main() {
}
}
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
if (K_LOAD_SHMEM != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
}
ksh[c * kshstride + d] = K_Tf;
ksh[c * kshstride + d] = K_Tf;
}
}
barrier();
}
barrier();
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
@@ -208,11 +236,55 @@ void main() {
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
if (K_LOAD_SHMEM == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
uint32_t col_vec = (idx + tid) % (MatBr / 4);
uint32_t row = (idx + tid) / (MatBr / 4);
if (idx + tid < Bc * MatBr / 4) {
f16vec4 K_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
}
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
ksh[row * kshstride + col_vec] = K_Tf;
}
}
barrier();
#if BLOCK_SIZE == 1
}
#endif
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
uint coord = (gl_SubgroupID * MatBc) * kshstride;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {
const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
}
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
@@ -222,26 +294,26 @@ void main() {
barrier();
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / Br;
uint32_t r = (idx + tid) % Br;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
}
}
barrier();
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float f = mask_cache[idx / WorkGroupSize];
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
if (!KV_bounds_check || j * Bc + c < KV) {
// Mask nem1 bounds check is handled when loading masks
ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
sfsh[c * sfshstride + r] += slopes * masks;
}
}
}
@@ -250,51 +322,145 @@ void main() {
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint r_vec = tile_row(r) / 4;
const uint r_comp = tile_row(r) % 4;
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
}
float Moldf = Mf[r];
// Compute max across the row
rowmaxf = subgroupMax(rowmaxf);
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
float Pf[rows_per_thread];
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
}
// Calculate and store Pf in Psh
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
const uint col = c * cols_per_iter + col_tid;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
const uint row = tile_row(r);
if (KV_bounds_check && j * Bc + col >= KV) {
Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
} else {
const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
[[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
Lf[r + vec_idx] += Pf[vec_idx];
}
Psh[col * psh_stride + row / 4] = Pf;
}
}
}
const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
// Each subgroup handles HSV/4 columns
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
// Preload V tiles for [Bc, 16 * num subgroups]
const uint v_rows = Bc;
const uint v_total = v_rows * v_cols;
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
[[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
const uint idx = i * gl_WorkGroupSize.x + tid;
const uint row = idx / v_cols;
const uint col = idx % v_cols;
const uint v_row = j * Bc + row;
const uint v_col = hsv_tile * MatBc * row_split + col * 4;
const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
const uint ib = coord / BLOCK_SIZE;
const uint iqs = coord % BLOCK_SIZE;
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
#else
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
ksh[row * vsh_stride + col] = f16vec4(0.0f);
}
}
#if BLOCK_SIZE == 1
}
#endif
barrier();
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
// Store SfMat to sfsh and load into Of
const uint osh_stride = row_split * MatBc / 4;
const uint o_offset = gl_SubgroupID * MatBc / 4;
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
const uint hsv_per_tile = row_split * MatBc;
const uint hsv_base = hsv_tile * hsv_per_tile;
const uint d_values_per_tile = hsv_per_tile / 4;
const uint d_start = hsv_tile * d_values_per_tile;
const uint d_end = min(d_start + d_values_per_tile, HSV / 4);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
[[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
const uint d = d_local * threads_per_rowgroup + col_tid;
const uint hsv_col = 4 * d;
if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
const uint local_hsv = (hsv_col - hsv_base) / 4;
Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
}
}
}
}
@@ -302,69 +468,8 @@ void main() {
barrier();
}
// prevent race on tmpsh
barrier();
// reduce across threads
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE M = Mf[r];
tmpsh[tid] = M;
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
M = max(M, tmpsh[tid ^ s]);
barrier();
tmpsh[tid] = M;
barrier();
}
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
eMf[r] = exp(Moldf[r] - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE L = Lf[r];
tmpsh[tid] = L;
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
L += tmpsh[tid ^ s];
barrier();
tmpsh[tid] = L;
barrier();
}
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
Of[r][d] += tmpshv4[tid ^ s];
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
}
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
Lf[r] = subgroupAdd(Lf[r]);
}
// If there is split_k, then the split_k resolve shader does the final
@@ -375,9 +480,12 @@ void main() {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
}
}
}
@@ -404,8 +512,9 @@ void main() {
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ACC_TYPE(ms);
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
Of[r][d_local] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
@@ -420,11 +529,12 @@ void main() {
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
}
}
@@ -434,9 +544,12 @@ void main() {
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
}
}
}
@@ -444,9 +557,12 @@ void main() {
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
}
}
}
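The shader above keeps a running row maximum `Mf` and a running sum `Lf`, rescaling the old sum by `eM = e^(Mold - M)` whenever the maximum changes. The following is a scalar sketch of that online-softmax update for a single row, under the assumption that only `M` and `L` are modeled (the rescaling of the output accumulator `Of` is omitted). The names `SoftmaxState` and `online_update` are hypothetical.

```cpp
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// Running state for one query row: current maximum M and sum of exponentials L.
struct SoftmaxState {
    float M = -FLT_MAX / 2.0f; // same idea as NEG_FLT_MAX_OVER_2 in the shader
    float L = 0.0f;
};

// Fold one tile of scores into the running state.
SoftmaxState online_update(SoftmaxState s, const std::vector<float> & scores) {
    float rowmax = -FLT_MAX / 2.0f;
    for (float v : scores) {
        rowmax = std::max(rowmax, v);
    }
    const float Mold = s.M;
    const float M    = std::max(rowmax, Mold); // M  = max(rowmax, Mold)
    const float eM   = std::exp(Mold - M);     // eM = e^(Mold - M)
    float L = eM * s.L;                        // rescale the old sum to the new maximum
    for (float v : scores) {
        L += std::exp(v - M);                  // P = e^(S - M)
    }
    return SoftmaxState{M, L};
}
```

Processing the tiles `{1, 2}` and `{3, 0.5}` in sequence should give the same `M` and, up to rounding, the same `L` as a single pass over all four scores.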
@@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
return max(elem0, elem1);
}
#if defined(BLOCK_SIZE)
#if BLOCK_SIZE > 1
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
@@ -85,7 +85,7 @@ void main() {
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if defined(BLOCK_SIZE)
#if BLOCK_SIZE > 1
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
@@ -98,7 +98,7 @@ void main() {
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if !defined(BLOCK_SIZE)
#if BLOCK_SIZE == 1
k_stride &= ~7;
v_stride &= ~7;
#endif
+105 -60
@@ -114,7 +114,7 @@ struct Params {
#define PARAMS_BINDING 4
#endif
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
// Just a very small float value.
@@ -160,14 +160,21 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
return v;
}
fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> {
return (*buf)[scalar_index >> 2u];
}
fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
return (*buf)[scalar_index >> 2u];
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
// initialize row max for online softmax
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
@@ -231,9 +238,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
// clear inter_shmem to ensure zero-initialized accumulators
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = 0.0;
}
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = 0.0;
}
// load k tile into shared memory
#if defined(KV_Q4_0)
@@ -309,48 +316,77 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
// accumulate q block * k block into registers across the entire KV tile
// TODO: this loop seems to be the current largest bottleneck
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
let inter_offset = kv_block * SG_MAT_N;
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<
subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
// this bracket exists to scope the lifetime of variables, reducing register pressure
{
#ifdef KV_DIRECT
let k_block_row = kv_tile + kv_block * SG_MAT_N;
let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
let k_block_row = kv_tile + subgroup_id * SG_MAT_N;
var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
#else
let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
#endif
for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
// load q submatrix from shared memory
var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
&q_shmem,
head_dim_block,
false,
HEAD_DIM_QK
);
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
let inter_offset = kv_block * SG_MAT_N;
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
// load k submatrix from device or shared memory
#ifdef KV_DIRECT
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
&K,
k_global_offset + head_dim_block,
true,
params.stride_k1
);
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
#else
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
&kv_shmem,
k_block_offset + head_dim_block,
true,
HEAD_DIM_QK
);
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
var t: u32 = 1u;
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
let h0 = t * SG_MAT_K;
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
#else
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = q0;
k_cur = k0;
let h1 = (t + 1u) * SG_MAT_K;
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
#else
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = q1g;
k_cur = k1g;
}
// handle odd tail
if (t < HEAD_DIM_QK / SG_MAT_K) {
let h = t * SG_MAT_K;
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
#ifdef KV_DIRECT
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
#else
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
q_cur = qn;
k_cur = kn;
}
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
#ifdef KV_DIRECT
k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
#else
k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
#endif
subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
}
// store acc to shared memory for softmax (S matrix from paper)
subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
}
#ifdef MASK
// load mask tile into shared memory for this KV block
// TODO: optimize and skip if mask is -INF for the entire tile
@@ -495,7 +531,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
false,
HEAD_DIM_V
);
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
let p_offset = kv_block * SG_MAT_N;
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
@@ -527,11 +562,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
// O += P * V
o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
}
// store O back to shared memory
subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
}
workgroupBarrier();
}
@@ -566,26 +599,38 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
o_shmem[idx] = f16(val);
}
}
workgroupBarrier();
#endif
// write output back to global memory
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
let exp_sum = exp_sum_shmem[q_tile_row];
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) { break; }
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
let scaled = f32(o_val) * scale;
dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
}
let exp_sum = exp_sum_shmem[q_tile_row];
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;
for (var elem_base = sg_inv_id * 4u;
elem_base < HEAD_DIM_V;
elem_base += subgroup_size * 4u) {
let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
let v = vec4<f32>(
f32(o_shmem[i0]) * scale,
f32(o_shmem[i1]) * scale,
f32(o_shmem[i2]) * scale,
f32(o_shmem[i3]) * scale
);
let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = v;
}
}
}
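The rewritten Q*K loop in the WGSL hunk above is a software-pipelined accumulation: the operands for the next step are loaded before the multiply-accumulate for the current step is issued, with the loop unrolled by two and an odd tail handled separately. The following is a scalar sketch of that control flow, assuming integers in place of subgroup matrices. The functions `load`, `accumulate_naive`, and `accumulate_pipelined` are hypothetical stand-ins, not part of the shader.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a subgroup matrix load.
static int64_t load(const std::vector<int64_t> & buf, std::size_t i) { return buf[i]; }

// Reference: load and accumulate in the same iteration.
int64_t accumulate_naive(const std::vector<int64_t> & q, const std::vector<int64_t> & k) {
    assert(q.size() == k.size());
    int64_t acc = 0;
    for (std::size_t i = 0; i < q.size(); ++i) {
        acc += load(q, i) * load(k, i);
    }
    return acc;
}

// Pipelined: prefetch step t while accumulating step t - 1.
int64_t accumulate_pipelined(const std::vector<int64_t> & q, const std::vector<int64_t> & k) {
    assert(q.size() == k.size());
    const std::size_t n = q.size();
    if (n == 0) return 0; // the shader always has at least one block; guarded here for safety

    int64_t acc   = 0;
    int64_t q_cur = load(q, 0);
    int64_t k_cur = load(k, 0);

    std::size_t t = 1;
    for (; t + 1 < n; t += 2) {
        const int64_t q0 = load(q, t);
        const int64_t k0 = load(k, t);
        acc += q_cur * k_cur;       // consume step t - 1
        q_cur = q0; k_cur = k0;

        const int64_t q1 = load(q, t + 1);
        const int64_t k1 = load(k, t + 1);
        acc += q_cur * k_cur;       // consume step t
        q_cur = q1; k_cur = k1;
    }
    if (t < n) {                    // odd tail: one more prefetch
        const int64_t qn = load(q, t);
        const int64_t kn = load(k, t);
        acc += q_cur * k_cur;
        q_cur = qn; k_cur = kn;
    }
    acc += q_cur * k_cur;           // drain the last prefetched pair
    return acc;
}
```

Both functions visit every index exactly once, so they return the same result for any length. The pipelined form only changes the order in which loads and accumulates are issued, which is what gives the GPU room to overlap memory latency with computation.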
+2 -3
@@ -2,7 +2,6 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "zendnnl.hpp"
#include <cstring>
@@ -122,8 +121,8 @@ static void ggml_zendnn_compute_forward_mul_mat(
GGML_TENSOR_BINARY_OP_LOCALS
ggml_type const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float;
ggml_type const vec_dot_type = src0->type;
ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
@@ -0,0 +1,156 @@
{#- ======== Template Parameters ======== #}
{%- set add_generation_prompt = add_generation_prompt if add_generation_prompt is defined else true %}
{%- set default_system_prompt = default_system_prompt if default_system_prompt is defined else true %}
{%- set reasoning_effort = reasoning_effort if reasoning_effort is defined else "high" %}
{%- set think_render_option = think_render_option if think_render_option is defined else "lastthink" %}
{#- ======== System Block State ======== #}
{%- set sys_ns = namespace(is_first_block=true) -%}
{#- ======== Find last user message index ======== #}
{%- set last_user_idx = namespace(value=-1) -%}
{%- for message in messages -%}
{%- if message.role == 'user' -%}
{%- set last_user_idx.value = loop.index0 -%}
{%- endif -%}
{%- endfor -%}
{#- ======== System messages renderers ======== #}
{%- macro render_system_message(user_system_messages) %}
{%- if default_system_prompt %}
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
{%- set sys_ns.is_first_block = false %}
{{- "## Provider System Prompt\n\nYou are Solar Open 100B, a large language model trained by Upstage AI, a Korean startup. Your knowledge cutoff is 2025-07. The current date is " + strftime_now("%Y-%m-%d") + "." }}
{%- endif -%}
{%- if user_system_messages %}
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
{%- set sys_ns.is_first_block = false %}
{{- "## System Prompt" }}
{%- for system_message in user_system_messages %}
{{- "\n\n" }}
{{- system_message }}
{%- endfor %}
{%- endif -%}
{%- endmacro %}
{%- macro render_tool_instruction(tools) %}
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
{%- set sys_ns.is_first_block = false %}
{{- "## Tools\n\n### Tool Call Instruction" }}
{{- "\nYou may invoke one or more tools to assist with the user's query. Available tools are provided in JSON Schema format: <|tools:begin|><|tool:begin|><tools-json-object><|tool:end|>...<|tools:end|>\n" }}
{{- "\n### Available Tools\n" }}
{{- "<|tools:begin|>" }}
{%- for tool in tools %}
{{- "<|tool:begin|>" }}
{{- tool.function | tojson }}
{{- "<|tool:end|>" }}
{%- endfor %}
{{- "<|tools:end|>\n" }}
{{- "\n### Tool Call Format\n" }}
{{- "For each tool call, return a JSON object with the following structure, enclosed within <|tool_call:begin|> and <|tool_call:end|> tags: \n<|tool_call:begin|><tool-call-id><|tool_call:name|><tool-name><|tool_call:args|><args-json-object><|tool_call:end|>\n" }}
{{- "- The <tool-call-id> must be a randomly generated string consisting of 10 lowercase letters (a-z) and/or digits (0-9) (e.g., a1b2c3d4e5)\n" }}
{{- "\n### Tool Response Format\n" }}
{{- "Each tool is responded by `tool` with the following structure:\n<|tool_response:id|><tool-call-id><|tool_response:name|><tool-name><|tool_response:result|><results><|tool_response:end|>\n" }}
{{- "- Ensure the <tool-call-id> matches the corresponding tool call" -}}
{%- endmacro %}
{%- macro render_json_response_format_instruction(response_format) %}
{%- if not sys_ns.is_first_block %}{{- "\n\n" }}{%- endif %}
{%- set sys_ns.is_first_block = false %}
{{- "## Output Format Constraint" }}
{{- "\n\nYour final response should follow the JSON schema: \n[Start of schema]" }}
{{- response_format }}
{{- "\n[End of schema]\nPlease ensure your answers adhere to this format and do not contain any unnecessary text." }}
{%- endmacro %}
{%- macro get_tool_name(messages, tool_call_id) %}
{%- for msg in messages -%}
{%- if msg.role == 'assistant' and msg.tool_calls -%}
{%- for tool_call in msg.tool_calls -%}
{%- if tool_call.id == tool_call_id -%}
{{- tool_call.function.name }}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- endfor -%}
{%- endmacro %}
{%- macro render_tool_arguments(tool_arguments) %}
{%- if tool_arguments is mapping -%}
{{- tool_arguments | tojson }}
{%- else -%}
{{- tool_arguments }}
{%- endif -%}
{%- endmacro %}
{#- ======== Render system message ======== #}
{%- set ns = namespace(system_messages=[]) -%}
{%- for message in messages -%}
{%- if message.role == 'system' -%}
{%- set ns.system_messages = ns.system_messages + [message.content] -%}
{%- endif -%}
{%- endfor -%}
{%- if ns.system_messages or default_system_prompt or tools or response_format -%}
{{- "<|begin|>system<|content|>" }}
{{- render_system_message(ns.system_messages) }}
{%- if tools -%}
{{- render_tool_instruction(tools) }}
{%- endif %}
{%- if response_format -%}
{{- render_json_response_format_instruction(response_format) }}
{%- endif %}
{{- "<|end|>" }}
{%- endif -%}
{#- ======== Render main messages ======== #}
{%- for message in messages -%}
{%- if message.role == 'user' -%}
{{- "<|begin|>user<|content|>" + message.content + "<|end|>" }}
{%- elif message.role == 'tool' -%}
{%- set prev_is_tool = loop.index0 > 0 and messages[loop.index0 - 1].role == 'tool' -%}
{%- set next_is_tool = loop.index0 < (messages | length - 1) and messages[loop.index0 + 1].role == 'tool' -%}
{%- if not prev_is_tool -%}
{{- "<|begin|>tool<|tool_response|>" }}
{%- endif -%}
{{- "<|tool_response:begin|>" + message.tool_call_id + "<|tool_response:name|>" }}
{{- get_tool_name(messages, message.tool_call_id) }}
{{- "<|tool_response:result|>" }}
{{- message.content }}
{{- "<|tool_response:end|>" }}
{%- if not next_is_tool -%}
{{- "<|end|>" }}
{%- endif -%}
{%- elif message.role == 'assistant' -%}
{#- ======== Assistant Thinking ======== #}
{%- if think_render_option == "all" -%}
{%- if message.reasoning -%}
{{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }}
{%- endif -%}
{%- elif think_render_option == "lastthink" -%}
{%- if message.reasoning and loop.index0 > last_user_idx.value -%}
{{- "<|begin|>assistant<|think|>" + message.reasoning + "<|end|>" }}
{%- endif -%}
{%- endif -%}
{#- ======== Assistant Messages ======== #}
{%- if message.tool_calls -%}
{{- "<|begin|>assistant<|tool_calls|>" }}
{%- for tool_call in message.tool_calls -%}
{{- "<|tool_call:begin|>" + tool_call.id +"<|tool_call:name|>" + tool_call.function.name + "<|tool_call:args|>" }}
{{- render_tool_arguments(tool_call.function.arguments) }}
{{- "<|tool_call:end|>" }}
{%- endfor -%}
{{- "<|calls|>" }}
{%- else -%}
{{- "<|begin|>assistant<|content|>" + message.content + "<|end|>" }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if reasoning_effort in ["low", "minimal"] -%}
{{- "<|begin|>assistant<|think|><|end|>" }}
{%- endif -%}
{{- "<|begin|>assistant" }}
{%- endif -%}
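Note on the template above: with the default `think_render_option` of `"lastthink"`, an assistant message's reasoning is rendered only when the message comes after the last user message. Each turn is framed as `<|begin|>role<|content|>...<|end|>`. The sketch below reproduces just that part in C++. It deliberately omits the system block, tool calls, tool responses and the `reasoning_effort` handling, and the names `Msg` and `render_solar_sketch` are hypothetical.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical message type; `reasoning` is empty when there is none.
struct Msg {
    std::string role;
    std::string content;
    std::string reasoning;
};

// Simplified rendering of user/assistant turns with "lastthink" behaviour.
// Not a replacement for the Jinja template: system, tool and
// reasoning_effort handling are left out on purpose.
inline std::string render_solar_sketch(const std::vector<Msg> & msgs,
                                       bool add_generation_prompt) {
    // Find the index of the last user message (-1 if none).
    int last_user = -1;
    for (std::size_t i = 0; i < msgs.size(); ++i) {
        if (msgs[i].role == "user") {
            last_user = static_cast<int>(i);
        }
    }

    std::string out;
    for (std::size_t i = 0; i < msgs.size(); ++i) {
        const Msg & m = msgs[i];
        if (m.role == "user") {
            out += "<|begin|>user<|content|>" + m.content + "<|end|>";
        } else if (m.role == "assistant") {
            // Reasoning is kept only after the last user turn.
            if (!m.reasoning.empty() && static_cast<int>(i) > last_user) {
                out += "<|begin|>assistant<|think|>" + m.reasoning + "<|end|>";
            }
            out += "<|begin|>assistant<|content|>" + m.content + "<|end|>";
        }
    }
    if (add_generation_prompt) {
        out += "<|begin|>assistant";
    }
    return out;
}
```

Reasoning from earlier turns is dropped from the prompt, which keeps the context shorter in multi-turn conversations.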
+40
@@ -0,0 +1,40 @@
#!/usr/bin/env pwsh
# Basedir on device
$basedir=".\pkg-snapdragon"
$cli_opts=$args
$model="Llama-3.2-3B-Instruct-Q4_0.gguf"
if ($null -ne $env:M) {
$model=$env:M
}
$device="HTP0"
if ($null -ne $env:D) {
$device=$env:D
}
if ($null -ne $env:V) {
$env:GGML_HEXAGON_VERBOSE=$env:V
}
if ($null -ne $env:OPMASK) {
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
}
if ($null -ne $env:NHVX) {
$env:GGML_HEXAGON_NHVX=$env:NHVX
}
if ($null -ne $env:NDEV) {
$env:GGML_HEXAGON_NDEV=$env:NDEV
}
$env:ADSP_LIBRARY_PATH="$basedir\lib"
& "$basedir\bin\llama-bench.exe" `
--mmap 0 -m $basedir\..\..\gguf\$model `
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
--batch-size 128 -ngl 99 --device $device $cli_opts
+53
@@ -0,0 +1,53 @@
#!/usr/bin/env pwsh
# Basedir on device
$basedir=".\pkg-snapdragon"
$cli_opts=$args
$model="Llama-3.2-3B-Instruct-Q4_0.gguf"
if ($null -ne $env:M) {
$model=$env:M
}
$device="HTP0"
if ($null -ne $env:D) {
$device=$env:D
}
if ($null -ne $env:V) {
$env:GGML_HEXAGON_VERBOSE=$env:V
}
if ($null -ne $env:E) {
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
}
if ($null -ne $env:SCHED) {
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts += "-v"
}
if ($null -ne $env:PROF) {
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
}
if ($null -ne $env:OPMASK) {
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
}
if ($null -ne $env:NHVX) {
$env:GGML_HEXAGON_NHVX=$env:NHVX
}
if ($null -ne $env:NDEV) {
$env:GGML_HEXAGON_NDEV=$env:NDEV
}
$env:ADSP_LIBRARY_PATH="$basedir\lib"
& "$basedir\bin\llama-completion.exe" `
--no-mmap -no-cnv -m $basedir\..\..\gguf\$model `
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on `
-ngl 99 --device $device $cli_opts
+56
@@ -0,0 +1,56 @@
#!/usr/bin/env pwsh
# Basedir on device
$basedir=".\pkg-snapdragon"
if ($args.Count -eq 0) {
Write-Host "No arguments provided. Expected the tool and its arguments to run."
exit -1
}
$tool=$args[0]
$cli_opts=@()
if ($args.Count -gt 1) {
$cli_opts=$args[1..($args.Count - 1)]
}
$device="HTP0"
if ($null -ne $env:D) {
$device=$env:D
}
if ($null -ne $env:V) {
$env:GGML_HEXAGON_VERBOSE=$env:V
}
if ($null -ne $env:E) {
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
}
if ($null -ne $env:SCHED) {
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts += "-v"
}
if ($null -ne $env:PROF) {
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
}
if ($null -ne $env:OPMASK) {
$env:GGML_HEXAGON_OPMASK=$env:OPMASK
}
if ($null -ne $env:NHVX) {
$env:GGML_HEXAGON_NHVX=$env:NHVX
}
if ($null -ne $env:NDEV) {
$env:GGML_HEXAGON_NDEV=$env:NDEV
}
$env:ADSP_LIBRARY_PATH="$basedir\lib"
& "$basedir\bin\$tool" `
$cli_opts
+105
@@ -0,0 +1,105 @@
# Running as Administrator is NOT strictly necessary for User-scope env vars,
# but is recommended for creating directories in the C:\ root if permissions are restricted.
$ErrorActionPreference = "Stop"
# --- Configuration ---
$BaseDir = "C:\Qualcomm"
# SDK 1: Hexagon
$HexagonUrl = "https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz"
$HexagonParent = Join-Path $BaseDir "Hexagon_SDK"
$HexagonSdkVersion = "6.4.0.2"
$HexagonToolsVersion = "19.0.04"
$HexagonSdkTarget = Join-Path $HexagonParent $HexagonSdkVersion
$HexagonToolsTarget = Join-Path $HexagonSdkTarget "\tools\HEXAGON_Tools\$HexagonToolsVersion"
# SDK 2: OpenCL
$OpenCLUrl = "https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz"
$OpenCLParent = Join-Path $BaseDir "OpenCL_SDK"
$OpenCLVersion = "2.3.2"
$OpenCLTarget = Join-Path $OpenCLParent $OpenCLVersion
# --- Helper Function ---
function Install-QualcommSDK {
param (
[string]$Url,
[string]$ParentDir,
[string]$TargetDir,
[string]$Name
)
# 1. Create Parent Directory
if (-not (Test-Path -Path $ParentDir)) {
Write-Host "Creating directory: $ParentDir" -ForegroundColor Cyan
New-Item -Path $ParentDir -ItemType Directory -Force | Out-Null
}
# 2. Check for Specific Version Directory
if (Test-Path -Path $TargetDir) {
Write-Host "$Name ($TargetDir) already exists. Skipping download." -ForegroundColor Green
}
else {
Write-Host "$Name not found. Preparing to download..." -ForegroundColor Yellow
# Create the target directory to extract into
New-Item -Path $TargetDir -ItemType Directory -Force | Out-Null
# Define temporary archive path
$TempFile = Join-Path $ParentDir "temp_sdk.tar.xz"
try {
# Download
Write-Host "Downloading from: $Url"
Invoke-WebRequest -Uri $Url -OutFile $TempFile
# Untar
# Note: We assume Windows includes tar.exe (Win 10 build 17063+)
Write-Host "Extracting archive to $TargetDir..."
# -C extracts into the PARENT of the target directory; the archive is presumably rooted at the versioned folder
tar -xJvf $TempFile -C $TargetDir\..
Write-Host "Extraction complete." -ForegroundColor Green
}
catch {
Write-Error "Failed to download or extract $Name. Error: $_"
# Cleanup target dir if failed so script tries again next time
Remove-Item -Path $TargetDir -Recurse -Force -ErrorAction SilentlyContinue
}
finally {
# Cleanup Archive
if (Test-Path $TempFile) { Remove-Item $TempFile -Force }
}
}
}
# --- Execution ---
# 1. Ensure Base C:\Qualcomm exists
if (-not (Test-Path $BaseDir)) {
New-Item -Path $BaseDir -ItemType Directory -Force | Out-Null
}
# 2. Run Install Logic
Install-QualcommSDK -Url $HexagonUrl -ParentDir $HexagonParent -TargetDir $HexagonSdkTarget -Name "Hexagon SDK"
Install-QualcommSDK -Url $OpenCLUrl -ParentDir $OpenCLParent -TargetDir $OpenCLTarget -Name "OpenCL SDK"
# --- Environment Variables ---
Write-Host "`nSetting Environment Variables..." -ForegroundColor Cyan
# Set OPENCL_SDK_ROOT
[System.Environment]::SetEnvironmentVariable('OPENCL_SDK_ROOT', $OpenCLTarget, [System.EnvironmentVariableTarget]::User)
$env:OPENCL_SDK_ROOT = $OpenCLTarget # Set for current session as well
Write-Host "OPENCL_SDK_ROOT set to: $OpenCLTarget"
# Set HEXAGON_SDK_ROOT
[System.Environment]::SetEnvironmentVariable('HEXAGON_SDK_ROOT', $HexagonSdkTarget, [System.EnvironmentVariableTarget]::User)
$env:HEXAGON_SDK_ROOT = $HexagonSdkTarget # Set for current session as well
Write-Host "HEXAGON_SDK_ROOT set to: $HexagonSdkTarget"
# Set HEXAGON_TOOLS_ROOT
[System.Environment]::SetEnvironmentVariable('HEXAGON_TOOLS_ROOT', $HexagonToolsTarget, [System.EnvironmentVariableTarget]::User)
$env:HEXAGON_TOOLS_ROOT = $HexagonToolsTarget # Set for current session as well
Write-Host "HEXAGON_TOOLS_ROOT set to: $HexagonToolsTarget"
-5
@@ -1630,11 +1630,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
if (!cparams.offload_kqv) {
// all nodes between the KV store and the attention output are run on the CPU
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
}
ggml_flash_attn_ext_add_sinks(cur, sinks);
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
+2 -3
@@ -54,7 +54,6 @@ std::string DEFAULT_JSON = R"({
],
"bos_token": "<s>",
"eos_token": "</s>",
"tools": [],
"add_generation_prompt": true
})";
@@ -481,7 +480,7 @@ int main_automated_tests(void) {
/* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- 
elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[AVAILABLE_TOOLS] [[/AVAILABLE_TOOLS][INST] You are a helpful assistant\n\nAnother question[/INST]",
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] You are a helpful assistant\n\nAnother question[/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
@@ -489,7 +488,7 @@ int main_automated_tests(void) {
/* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- 
elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[AVAILABLE_TOOLS][[/AVAILABLE_TOOLS][INST]You are a helpful assistant\n\nAnother question[/INST]",
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]You are a helpful assistant\n\nAnother question[/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
+129 -1
@@ -592,7 +592,7 @@ static void test_peg_parser(common_chat_templates * tmpls, const std::function<v
}
if (diff.tool_call_index != std::string::npos) {
if (!diff.tool_call_delta.name.empty()) {
msg_accum.tool_calls.push_back({diff.tool_call_delta.name, "", ""});
msg_accum.tool_calls.push_back({diff.tool_call_delta.name, "", diff.tool_call_delta.id});
}
if (!diff.tool_call_delta.arguments.empty()) {
msg_accum.tool_calls.back().arguments += diff.tool_call_delta.arguments;
@@ -3799,6 +3799,134 @@ static void test_template_output_peg_parsers() {
});
}
{
// Solar-Open-100B
auto tmpls = read_templates("models/templates/upstage-Solar-Open-100B.jinja");
// Test basic message
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|content|>Hello, world!\nWhat's up?";
t.expect = message_assist;
});
// Test basic message and reasoning
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|think|>I'm\nthinking<|end|><|begin|>assistant<|content|>Hello, world!\nWhat's up?";
t.expect = message_assist_thoughts;
});
// Test basic message and reasoning_effort = low
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|content|>Hello, world!\nWhat's up?";
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
t.expect = message_assist;
});
// Test tool call
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|tool_calls|>"
"<|tool_call:begin|>123456789"
"<|tool_call:name|>special_function"
"<|tool_call:args|>{\"arg1\":1}"
"<|tool_call:end|>";
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
t.params.tools = {special_function_tool};
t.expect = message_assist_call_id;
});
// Test tool call with reasoning
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|think|>I'm\nthinking<|end|>"
"<|begin|>assistant<|tool_calls|>"
"<|tool_call:begin|>0"
"<|tool_call:name|>special_function"
"<|tool_call:args|>{\"arg1\":1}"
"<|tool_call:end|>";
t.params.tools = {special_function_tool};
t.expect = message_assist_thoughts_call_idx;
});
// Test tool call with reasoning and tool_choice = required
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|think|>I'm\nthinking<|end|>"
"<|begin|>assistant<|tool_calls|>"
"<|tool_call:begin|>0"
"<|tool_call:name|>special_function"
"<|tool_call:args|>{\"arg1\":1}"
"<|tool_call:end|>";
t.params.tools = {special_function_tool};
t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
t.expect = message_assist_thoughts_call_idx;
});
// Test tool call without reasoning and tool_choice = required
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|tool_calls|>"
"<|tool_call:begin|>0"
"<|tool_call:name|>special_function"
"<|tool_call:args|>{\"arg1\":1}"
"<|tool_call:end|>";
t.params.tools = {special_function_tool};
t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED;
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
t.expect = message_assist_call_idx;
});
// Test parallel tool calls
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|think|>I'm\nthinking<|end|>"
"<|begin|>assistant<|tool_calls|>"
"<|tool_call:begin|>0"
"<|tool_call:name|>special_function"
"<|tool_call:args|>{\"arg1\":1}"
"<|tool_call:end|>"
"<|tool_call:begin|>1"
"<|tool_call:name|>special_function_with_opt"
"<|tool_call:args|>{\"arg1\": 1, \"arg2\": 2}"
"<|tool_call:end|>";
t.params.parallel_tool_calls = true;
t.params.tools = {special_function_tool, special_function_tool_with_optional_param};
t.expect.reasoning_content = "I'm\nthinking";
t.expect.tool_calls = {{
/* .name = */ "special_function",
/* .arguments = */ R"({"arg1": 1})",
/* .id = */ "0",
}, {
/* .name = */ "special_function_with_opt",
/* .arguments = */ R"({"arg1": 1, "arg2": 2})",
/* .id = */ "1",
}};
});
// Test response format
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|think|>I need to output the invoice details in JSON<|end|>"
"<|begin|>assistant<|content|>"
R"({"amount": 123.45, "date": "2025-12-03"})";
t.params.json_schema = invoice_schema;
t.expect.reasoning_content = "I need to output the invoice details in JSON";
t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})";
});
// Test response format no reasoning
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "<|content|>"
R"({"amount": 123.45, "date": "2025-12-03"})";
t.params.chat_template_kwargs["reasoning_effort"] = "\"low\"";
t.params.json_schema = invoice_schema;
t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})";
});
}
}
static void test_msg_diffs_compute() {
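Note on the Solar-Open-100B tests above: every tool call in the test inputs has the shape `<|tool_call:begin|>ID<|tool_call:name|>NAME<|tool_call:args|>ARGS<|tool_call:end|>`, and parallel calls are simply concatenated. The sketch below splits such a string with plain substring search. It is not the PEG parser that `test_peg_parser` exercises, the names `ToolCallSketch` and `parse_tool_calls_sketch` are hypothetical, and it returns the argument text verbatim. It would also break if the arguments themselves contained the end marker.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical container for one parsed call.
struct ToolCallSketch {
    std::string id;
    std::string name;
    std::string arguments;
};

// Extracts id/name/args triples from concatenated tool-call segments.
// Stops at the first malformed or incomplete segment.
inline std::vector<ToolCallSketch> parse_tool_calls_sketch(const std::string & text) {
    const std::string k_begin = "<|tool_call:begin|>";
    const std::string k_name  = "<|tool_call:name|>";
    const std::string k_args  = "<|tool_call:args|>";
    const std::string k_end   = "<|tool_call:end|>";

    std::vector<ToolCallSketch> calls;
    std::size_t pos = 0;
    while (true) {
        const std::size_t b = text.find(k_begin, pos);
        if (b == std::string::npos) break;
        const std::size_t id_start = b + k_begin.size();

        const std::size_t n = text.find(k_name, id_start);
        if (n == std::string::npos) break;
        const std::size_t name_start = n + k_name.size();

        const std::size_t a = text.find(k_args, name_start);
        if (a == std::string::npos) break;
        const std::size_t args_start = a + k_args.size();

        const std::size_t e = text.find(k_end, args_start);
        if (e == std::string::npos) break;

        calls.push_back({
            text.substr(id_start, n - id_start),
            text.substr(name_start, a - name_start),
            text.substr(args_start, e - args_start),
        });
        pos = e + k_end.size();
    }
    return calls;
}
```

The ID is carried through to the result, which matches the intent of the one-line fix earlier in this file where `diff.tool_call_delta.id` replaces the empty string.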
Binary file not shown.
+63 -83
@@ -48,11 +48,8 @@ enum server_state {
struct server_slot {
int id;
llama_batch batch_spec = {};
// TODO: change to unique_ptrs for consistency:
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
@@ -259,7 +256,7 @@ struct server_slot {
}
bool can_speculate() const {
return ctx_dft;
return !!spec;
}
void add_token(const completion_token_output & token) {
@@ -295,6 +292,7 @@ struct server_slot {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
n_draft_max = 0;
}
return n_draft_max;
}
@@ -397,6 +395,8 @@ struct server_slot {
draft_ratio, n_draft_accepted, n_draft_total
);
}
common_speculative_print_stats(spec);
}
json to_json(bool only_metrics = false) const {
@@ -553,18 +553,13 @@ private:
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result_ptr llama_init;
common_init_result_ptr llama_init_dft;
llama_context * ctx = nullptr;
bool vocab_dft_compatible = true;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch {};
llama_model_ptr model_dft;
bool add_bos_token = true;
int32_t n_ctx; // total context for all clients / slots
@@ -597,13 +592,8 @@ private:
// Clear any sampling context
for (server_slot & slot : slots) {
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
@@ -648,44 +638,39 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.has_speculative()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
if (params_base.speculative.has_dft()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative;
auto params_dft = params_base;
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
params_dft.cache_type_k = params_base.speculative.cache_type_k;
params_dft.cache_type_v = params_base.speculative.cache_type_v;
params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams_dft;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
params_dft.cache_type_k = params_spec.cache_type_k;
params_dft.cache_type_v = params_spec.cache_type_v;
params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads;
params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides;
if (params_spec.cpuparams.n_threads > 0) {
params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads;
}
llama_init_dft = common_init_from_params(params_dft);
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
model_dft = llama_init_dft->model();
auto mparams_dft = common_model_params_to_llama(params_dft);
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
return false;
}
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
if (!vocab_dft_compatible) {
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
}
const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
// the context is not needed - we will create one for each slot
llama_init_dft->free_context();
params_base.speculative.model_dft = model_dft.get();
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
}
std::string & mmproj_path = params_base.mmproj.path;
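Note on the draft-model hunk above: as far as the flattened diff shows, the new code copies the base parameters and then overrides a few fields from the speculative settings. `n_ctx` falls back to the target's per-sequence context when it is 0, `n_batch` is set to that same per-sequence context, and the thread counts are overridden only when the speculative `n_threads` is positive. A minimal sketch of those rules follows, using hypothetical plain structs (`SpecParamsSketch`, `DraftParamsSketch`) rather than the real `common_params` types.

```cpp
// Hypothetical stand-ins for the speculative and draft parameter sets.
struct SpecParamsSketch {
    int n_ctx           = 0;  // 0 means "use the target's per-sequence context"
    int n_threads       = 0;  // <= 0 means "inherit from the base params"
    int n_threads_batch = 0;
};

struct DraftParamsSketch {
    int n_ctx;
    int n_batch;
    int n_threads;
    int n_threads_batch;
};

// Mirrors the override rules read from the diff; illustrative only.
inline DraftParamsSketch make_draft_params_sketch(const SpecParamsSketch & spec,
                                                  int target_n_ctx_seq,
                                                  int base_threads,
                                                  int base_threads_batch) {
    DraftParamsSketch d;
    d.n_ctx   = spec.n_ctx == 0 ? target_n_ctx_seq : spec.n_ctx;
    d.n_batch = target_n_ctx_seq;
    // Start from the base values, as `auto params_dft = params_base;` does.
    d.n_threads       = base_threads;
    d.n_threads_batch = base_threads_batch;
    if (spec.n_threads > 0) {
        d.n_threads       = spec.n_threads;
        d.n_threads_batch = spec.n_threads_batch;
    }
    return d;
}
```

Compared with the removed lines, which always copied the speculative thread counts, the guard keeps the base thread settings when no draft-specific value was given.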
@@ -695,6 +680,7 @@ private:
}
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
@@ -702,6 +688,7 @@ private:
mparams.warmup = params_base.warmup;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
@@ -718,11 +705,6 @@ private:
params_base.n_cache_reuse = 0;
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
}
if (params_base.has_speculative()) {
SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
return false;
}
}
if (!llama_memory_can_shift(llama_get_memory(ctx))) {
@@ -757,29 +739,24 @@ private:
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slot.id = i;
slot.ctx = ctx;
slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.mctx = mctx;
slot.mctx = mctx;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return false;
}
slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return false;
}
for (auto & pair : params_base.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
// try speculative decoding
{
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return false;
}
SRV_WRN("%s", "speculative decoding context initialized\n");
} else {
SRV_WRN("%s", "speculative decoding context not initialized\n");
}
}
@@ -1059,7 +1036,7 @@ private:
return res;
}
std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) {
std::vector<common_adapter_lora_info> construct_lora_list(const std::map<int, float> & config) const {
std::vector<common_adapter_lora_info> output = params_base.lora_adapters; // copy
for (size_t i = 0; i < output.size(); ++i) {
auto it = config.find(i);
@@ -1162,7 +1139,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
@@ -1179,14 +1156,6 @@ private:
slot.smpl.reset();
}
// initialize draft batch
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
}
slot.task = std::make_unique<const server_task>(std::move(task));
slot.state = slot.task->is_child()
@@ -2059,19 +2028,23 @@ private:
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
int n_draft_max = slot.get_n_draft_max();
const int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
const auto & params_spec = slot.task->params.speculative;
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
if (draft.size() > (size_t) n_draft_max) {
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
draft.resize(n_draft_max);
}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
@@ -2742,6 +2715,10 @@ private:
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
@@ -2813,6 +2790,9 @@ private:
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
// rollback to the state before sampling the draft tokens
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+23
@@ -5,6 +5,7 @@
#include "llama.h"
#include "chat.h"
#include "sampling.h"
#include "speculative.h"
#include "json-schema-to-grammar.h"
using json = nlohmann::ordered_json;
@@ -76,6 +77,11 @@ json task_params::to_json(bool only_metrics) const {
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_c_rate", speculative.ngram_check_rate},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -135,6 +141,11 @@ json task_params::to_json(bool only_metrics) const {
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
{"speculative.type", common_speculative_type_to_str(speculative.type)},
{"speculative.ngram_size_n", speculative.ngram_size_n},
{"speculative.ngram_size_m", speculative.ngram_size_m},
{"speculative.ngram_c_rate", speculative.ngram_check_rate},
{"speculative.ngram_m_hits", speculative.ngram_min_hits},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
@@ -242,6 +253,18 @@ task_params server_task::params_from_json_cmpl(
params.speculative.n_min = std::max(params.speculative.n_min, 0);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n);
params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m);
params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate);
params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits);
// clamp the n-gram parameters to the range [1, 1024]
params.speculative.ngram_size_n = std::max(1, std::min(1024, (int) params.speculative.ngram_size_n));
params.speculative.ngram_size_m = std::max(1, std::min(1024, (int) params.speculative.ngram_size_m));
params.speculative.ngram_check_rate = std::max(1, std::min(1024, (int) params.speculative.ngram_check_rate));
params.speculative.ngram_min_hits = std::max(1, std::min(1024, (int) params.speculative.ngram_min_hits));
// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
+42 -17
@@ -61,7 +61,7 @@
"remark-math": "^6.0.0",
"sass": "^1.93.3",
"storybook": "^10.0.7",
"svelte": "^5.0.0",
"svelte": "^5.38.2",
"svelte-check": "^4.0.0",
"tailwind-merge": "^3.3.1",
"tailwind-variants": "^3.2.2",
@@ -88,6 +88,7 @@
"version": "2.3.0",
"resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz",
"integrity": "sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"@jridgewell/gen-mapping": "^0.3.5",
@@ -867,6 +868,7 @@
"integrity": "sha512-oJrXtQiAXLvT9clCf1K4kxp3eKsQhIaZqxEyowkBcsvZDdZkbWrVmnGknxs5flTD0VGsxrxKgBCZty1EzoiMzA==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"dependencies": {
"@swc/helpers": "^0.5.0"
}
@@ -898,7 +900,6 @@
"version": "2.3.5",
"resolved": "https://registry.npmjs.org/@jridgewell/remapping/-/remapping-2.3.5.tgz",
"integrity": "sha512-LI9u/+laYG4Ds1TDKSJW2YPrIlcVYOwi2fUC6xB43lueCjgxV4lffOCZCtYFiH6TNOX+tQKXx97T4IKHbhyHEQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"@jridgewell/gen-mapping": "^0.3.5",
@@ -2031,6 +2032,7 @@
"integrity": "sha512-rO+YQhHucy47Vh67z318pALmd6x+K1Kj30Fb4a6oOEw4xn4zCo9KTmkMWs24c4oduEXD/eJu3badlRmsVXzyfA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"ts-dedent": "^2.0.0",
"type-fest": "~2.19"
@@ -2114,6 +2116,7 @@
"integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -2153,6 +2156,7 @@
"integrity": "sha512-YZs/OSKOQAQCnJvM/P+F1URotNnYNeU3P2s4oIpzm1uFaqUEqRxUB0g5ejMjEb5Gjb9/PiBI5Ktrq4rUUF8UVQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^5.0.0",
"debug": "^4.4.1",
@@ -2568,6 +2572,7 @@
"integrity": "sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@babel/code-frame": "^7.10.4",
"@babel/runtime": "^7.12.5",
@@ -2735,6 +2740,7 @@
"integrity": "sha512-bJFoMATwIGaxxx8VJPeM8TonI8t579oRvgAuT8zFugJsJZgzqv0Fu8Mhp68iecjzG7cnN3mO2dJQ5uUM2EFrgQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -2802,6 +2808,7 @@
"integrity": "sha512-kVIaQE9vrN9RLCQMQ3iyRlVJpTiDUY6woHGb30JDkfJErqrQEmtdWH3gV0PBAfGZgQXoqzXOO0T3K6ioApbbAA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@typescript-eslint/scope-manager": "8.37.0",
"@typescript-eslint/types": "8.37.0",
@@ -3026,6 +3033,7 @@
"integrity": "sha512-tJxiPrWmzH8a+w9nLKlQMzAKX/7VjFs50MWgcAj7p9XQ7AQ9/35fByFYptgPELyLw+0aixTnC4pUWV+APcZ/kw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@testing-library/dom": "^10.4.0",
"@testing-library/user-event": "^14.6.1",
@@ -3129,6 +3137,7 @@
"integrity": "sha512-oukfKT9Mk41LreEW09vt45f8wx7DordoWUZMYdY/cyAk7w5TWkTRCNZYF7sX7n2wB7jyGAl74OxgwhPgKaqDMQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@vitest/utils": "3.2.4",
"pathe": "^2.0.3",
@@ -3186,6 +3195,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -3738,8 +3748,7 @@
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz",
"integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==",
"dev": true,
"license": "MIT",
"peer": true
"license": "MIT"
},
"node_modules/debug": {
"version": "4.4.1",
@@ -3840,10 +3849,9 @@
}
},
"node_modules/devalue": {
"version": "5.3.2",
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.3.2.tgz",
"integrity": "sha512-UDsjUbpQn9kvm68slnrs+mfxwFkIflOhkanmyabZ8zOYk8SMEIbJ3TK+88g70hSIeytu4y18f0z/hYHMTrXIWw==",
"dev": true,
"version": "5.6.2",
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.6.2.tgz",
"integrity": "sha512-nPRkjWzzDQlsejL1WVifk5rvcFi/y1onBRxjaFMjZeR9mFpqu2gmAZ9xUB9/IEanEP/vBtGeGganC/GO1fmufg==",
"license": "MIT"
},
"node_modules/devlop": {
@@ -3973,6 +3981,7 @@
"dev": true,
"hasInstallScript": true,
"license": "MIT",
"peer": true,
"bin": {
"esbuild": "bin/esbuild"
},
@@ -4027,6 +4036,7 @@
"integrity": "sha512-QldCVh/ztyKJJZLr4jXNUByx3gR+TDYZCRXEktiZoUR3PGy4qCmSbkxcIle8GEwGpb5JBZazlaJ/CxLidXdEbQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.2.0",
"@eslint-community/regexpp": "^4.12.1",
@@ -6939,6 +6949,7 @@
}
],
"license": "MIT",
"peer": true,
"dependencies": {
"nanoid": "^3.3.11",
"picocolors": "^1.1.1",
@@ -7072,6 +7083,7 @@
"integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -7088,6 +7100,7 @@
"integrity": "sha512-pn1ra/0mPObzqoIQn/vUTR3ZZI6UuZ0sHqMK5x2jMLGrs53h0sXhkVuDcrlssHwIMk7FYrMjHBPoUSyyEEDlBQ==",
"dev": true,
"license": "MIT",
"peer": true,
"peerDependencies": {
"prettier": "^3.0.0",
"svelte": "^3.2.0 || ^4.0.0-next.0 || ^5.0.0-next.0"
@@ -7312,6 +7325,7 @@
"integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=0.10.0"
}
@@ -7322,6 +7336,7 @@
"integrity": "sha512-Xs1hdnE+DyKgeHJeJznQmYMIBG3TKIHJJT95Q58nHLSrElKlGQqDTR2HQ9fx5CN/Gk6Vh/kupBTDLU11/nDk/g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"scheduler": "^0.26.0"
},
@@ -7598,6 +7613,7 @@
"integrity": "sha512-4iya7Jb76fVpQyLoiVpzUrsjQ12r3dM7fIVz+4NwoYvZOShknRmiv+iu9CClZml5ZLGb0XMcYLutK6w9tgxHDw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@types/estree": "1.0.8"
},
@@ -7704,6 +7720,7 @@
"integrity": "sha512-elOcIZRTM76dvxNAjqYrucTSI0teAF/L2Lv0s6f6b7FOwcwIuA357bIE871580AjHJuSvLIRUosgV+lIWx6Rgg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"chokidar": "^4.0.0",
"immutable": "^5.0.2",
@@ -7938,6 +7955,7 @@
"integrity": "sha512-7smAu0o+kdm378Q2uIddk32pn0UdIbrtTVU+rXRVtTVTCrK/P2cCui2y4JH+Bl3NgEq1bbBQpCAF/HKrDjk2Qw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@storybook/global": "^5.0.0",
"@storybook/icons": "^1.6.0",
@@ -8079,12 +8097,13 @@
}
},
"node_modules/svelte": {
"version": "5.36.12",
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.36.12.tgz",
"integrity": "sha512-c3mWT+b0yBLl3gPGSHiy4pdSQCsPNTjLC0tVoOhrGJ6PPfCzD/RQpAmAfJtQZ304CAae2ph+L3C4aqds3R3seQ==",
"version": "5.48.3",
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.48.3.tgz",
"integrity": "sha512-w7QZ398cdNherTdiQ/v3SYLLGOO4948Jgjh04PYqtTYVohmBvbmFwLmo7pp8gp4/1tceRWfSTjHgjtfpCVNJmQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@ampproject/remapping": "^2.3.0",
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
"@sveltejs/acorn-typescript": "^1.0.5",
"@types/estree": "^1.0.5",
@@ -8092,8 +8111,9 @@
"aria-query": "^5.3.1",
"axobject-query": "^4.1.0",
"clsx": "^2.1.1",
"devalue": "^5.6.2",
"esm-env": "^1.2.1",
"esrap": "^2.1.0",
"esrap": "^2.2.1",
"is-reference": "^3.0.3",
"locate-character": "^3.0.0",
"magic-string": "^0.30.11",
@@ -8281,9 +8301,9 @@
}
},
"node_modules/svelte/node_modules/esrap": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/esrap/-/esrap-2.1.0.tgz",
"integrity": "sha512-yzmPNpl7TBbMRC5Lj2JlJZNPml0tzqoqP5B1JXycNUwtqma9AKCO0M2wHrdgsHcy1WRW7S9rJknAMtByg3usgA==",
"version": "2.2.2",
"resolved": "https://registry.npmjs.org/esrap/-/esrap-2.2.2.tgz",
"integrity": "sha512-zA6497ha+qKvoWIK+WM9NAh5ni17sKZKhbS5B3PoYbBvaYHZWoS33zmFybmyqpn07RLUxSmn+RCls2/XF+d0oQ==",
"license": "MIT",
"dependencies": {
"@jridgewell/sourcemap-codec": "^1.4.15"
@@ -8326,6 +8346,7 @@
"integrity": "sha512-gBXpgUm/3rp1lMZZrM/w7D8GKqshif0zAymAhbCyIt8KMe+0v9DQ7cdYLR4FHH/cKpdTXb+A/tKKU3eolfsI+g==",
"dev": true,
"license": "MIT",
"peer": true,
"funding": {
"type": "github",
"url": "https://github.com/sponsors/dcastil"
@@ -8356,7 +8377,8 @@
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.11.tgz",
"integrity": "sha512-2E9TBm6MDD/xKYe+dvJZAmg3yxIEDNRc0jwlNyDg/4Fil2QcSLjFKGVff0lAf1jjeaArlG/M75Ey/EYr/OJtBA==",
"dev": true,
"license": "MIT"
"license": "MIT",
"peer": true
},
"node_modules/tapable": {
"version": "2.2.2",
@@ -8569,6 +8591,7 @@
"integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -8934,6 +8957,7 @@
"integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.5.0",
@@ -9094,6 +9118,7 @@
"integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@types/chai": "^5.2.2",
"@vitest/expect": "3.2.4",
+1 -1
@@ -62,7 +62,7 @@
"remark-math": "^6.0.0",
"sass": "^1.93.3",
"storybook": "^10.0.7",
"svelte": "^5.0.0",
"svelte": "^5.38.2",
"svelte-check": "^4.0.0",
"tailwind-merge": "^3.3.1",
"tailwind-variants": "^3.2.2",