Compare commits

...

27 Commits

Author SHA1 Message Date
Mahekk Shaikh 00425e2ed1 cuda : add error checking for cudaMemcpyAsync in argsort (#17599)
* cuda : add error checking for cudaMemcpyAsync in argsort (#12836)

* fix indentation
2025-11-30 08:16:28 +08:00
Acly 385c3da5e6 vulkan : fix FA mask load with bounds check (coopmat2) (#17606) 2025-11-30 01:03:21 +01:00
Xuan-Son Nguyen ab49f094d2 server: move server-context to its own cpp|h (#17595)
* git mv

* add server-context.h

* add server-context.h

* clean up headers

* cont : cleanup

* also expose server_response_reader (to be used by CLI)

* fix windows build

* decouple server_routes and server_http

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-29 22:04:44 +01:00
Haiyue Wang 8c32d9d96d server: explicitly set the function name in lambda (#17538)
As [1] explained, the real debug message will be like:
	"res    operator(): operator() : queue result stop"

Set the name explicitly, the message is easy for debugging:
	"res    operator(): recv : queue result stop"

The left "operator()" is generated by 'RES_DBG() ... __func__'

[1]: https://clang.llvm.org/extra/clang-tidy/checks/bugprone/lambda-function-name.html

Signed-off-by: Haiyue Wang <haiyuewa@163.com>
2025-11-29 18:43:29 +01:00
Igor Smirnov 0874693b44 common : fix json schema with '\' in literals (#17307)
* Fix json schema with '\' in literals

* Add "literal string with escapes" test
2025-11-29 17:06:32 +01:00
Neo Zhang 7d2add51d8 sycl : support to malloc memory on device more than 4GB, update the doc and script (#17566)
Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2025-11-29 14:59:44 +02:00
ixgbe f698a79c63 ggml: replace hwcap with riscv_hwprobe for RVV detection (#17567)
Signed-off-by: Wang Yang <yangwang@iscas.ac.cn>
2025-11-29 14:56:31 +02:00
Ruben Ortlam 47a268ea50 Vulkan: MMVQ Integer Dot K-Quant and MUL_MAT_ID support (#16900)
* vulkan: split mul_mmq_funcs for mul_mat_vecq use

* add mxfp4 mmvq

* add q2_k mmvq

* add q3_k mmvq

* add q4_k and q5_k mmvq

* add q6_k mmvq

* handle 4x4 quants per mmvq thread

* enable MUL_MAT_ID mmvq support

* enable subgroup optimizations for mul_mat_vec_id shaders

* device tuning

* request prealloc_y sync after quantization

* fix indentation

* fix llvmpipe test failures

* fix mul_mat_id mmvq condition

* fix unused variable warning
2025-11-29 09:37:22 +01:00
Jeff Bolz 59d8d4e963 vulkan: improve topk perf for large k, fix overflow in unit tests (#17582) 2025-11-29 08:39:57 +01:00
Aleksei Nikiforov d82b7a7c1d gguf-py : fix passing non-native endian tensors (editor-gui and new-metadata) (#17553)
gguf_new_metadata.py reads data from reader.
Reader doesn't byteswap tensors to native endianness.
But writer does expect tensors in native endianness to convert them
into requested endianness.

There are two ways to fix this: update reader and do conversion to native endianness and back,
or skip converting endianness in writer in this particular USE-case.

gguf_editor_gui.py doesn't allow editing or viewing tensor data.
Let's go with skipping excessive byteswapping.

If eventually capability to view or edit tensor data is added,
tensor data should be instead byteswapped when reading it.
2025-11-28 20:53:01 +01:00
DAN™ 03914c7ef8 common : move all common_chat_parse_* to chat-parser.cpp. (#17481) 2025-11-28 19:29:36 +01:00
o7si 3ce7a65c2f server: fix: /metrics endpoint returning JSON-escaped Prometheus format (#17386)
* fix: /metrics endpoint returning JSON-escaped Prometheus format

* mod: remove string overload from ok() method
2025-11-28 19:14:00 +01:00
Diego Devesa e072b2052e ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched (#17276)
* ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched
Enabled in ggml-ci for testing.

* llama : update worst-case graph for unified cache

* ci : disable op offload in some tests

* fix spelling

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-28 17:33:23 +02:00
R0CKSTAR c6f7a423c8 [MUSA] enable fp16/fast_fp16/bf16_mma on PH1 (#17551)
* [MUSA] enable fp16/fast_fp16/bf16_mma on PH1

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-tile.cuh

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Address review comments

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-11-28 14:08:29 +01:00
Aman Gupta 2e7ef98f18 ggml-cuda: add stricter checking for fusion (#17568)
* ggml-cuda: make conditions for fusion more explicit

* ggml-cuda: remove size check as std::equal already does it
2025-11-28 20:34:51 +08:00
Fredrik Hultin ddf9f94389 server : add Anthropic Messages API support (#17570)
* server : add Anthropic Messages API support

* remove -@pytest.mark.slow from tool calling/jinja tests

* server : remove unused code and slow/skip on test_anthropic_vision_base64_with_multimodal_model in test_anthropic_api.py

* server : removed redundant n field logic in anthropic_params_from_json

* server : use single error object instead of error_array in streaming response handler for /v1/chat/completions and use unordered_set instead of set in to_json_anthropic_stream()

* server : refactor Anthropic API to use OAI conversion

* make sure basic test always go first

* clean up

* clean up api key check, add test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-11-28 12:57:04 +01:00
Piotr Wilkin (ilintar) ff55414c42 model : Qwen3 Next (#16095)
* Qwen3 Next - cleaned up version

* Whitespaces and stuff

* Correct minor errors

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Misc. fixes.

* Clean up code, add missing hybrid qualifier

* Did someone transpose the SOLVE_TRI result matrix? Perhaps...

* Whitespace

* Proper tensors for cb calls

* Use llama-graph.h vertical alignment

* BROKEN: chunking

* Set new tensors as inputs.

* Proper chunk logic

* It's the circle of life...

* More shenanigans for n_seq > 1

* Nail in the coffin?

* Fix Windows build

* Eh, one fails on Windows, the other fails on Mac... just use general capture.

* quant : cleanup

* model : cleanup

* qwen3 : cleanup

* cont : cleanup

* cont : cleanup

* ggml : revert change

* qwen3 : cleanup

* cont : cleanup

* Readd cmath

* qwen3 : fix typo

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Usual suspects

* fix my bad suggestion

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-28 12:02:56 +01:00
Johannes Gäßler 73955f7d2a CUDA: no FP16 arithmetic for vector FA kernel (#17558) 2025-11-28 10:29:09 +01:00
Jeff Bolz 35cf8887e1 vulkan: Implement GGML_OP_TRI (#17503)
* vulkan: Implement GGML_OP_TRI

* check types match
2025-11-28 10:07:29 +01:00
Radoslav Gerganov 15d2b46b4d rpc : cache and reuse compute graphs (#15405)
Store the last computed graph and reuse it when possible.
Also do not return response from GRAPH_COMPUTE and assume it always
completes successfully. If this this is not the case, the server closes
the connection. This saves us a network round trip to the server.
2025-11-28 08:33:51 +00:00
yulo 6bca76ff5e HIP: enable mul_mat_f for RDNA4 (#17437)
* enable mmf for rdna4

* move some mmvf to mmf

* revert lds128 for wmma loading

* Revert "revert lds128 for wmma loading"

This reverts commit db9ae8b6b4.

* Revert "enable mmf for rdna4"

This reverts commit 698c9f2418.

* Revert "move some mmvf to mmf"

This reverts commit 99b92bd665.

* enable mul_mat for rdna4

---------

Co-authored-by: zhang hui <you@example.com>
2025-11-28 08:24:30 +01:00
Piotr Wilkin (ilintar) cd0e3a7a3b SOLVE_TRI CUDA kernel for small matrices (#17457) 2025-11-28 12:15:32 +08:00
Neo Zhang Jianyu efaaccdd69 refactor pad_reflect_1d to make the UT case pass (#17204)
Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
2025-11-28 08:50:56 +08:00
Jeff Bolz 4abef75f2c vulkan: Implement SOLVE_TRI (#17486)
* vulkan: Implement SOLVE_TRI

* load B matrix through shared memory

* use FLOAT_TYPE
2025-11-27 15:48:00 +01:00
Georgi Gerganov c386114922 arch : add description about LLM_TENSOR_INFOS (#17550) 2025-11-27 16:34:13 +02:00
Georgi Gerganov 6783b11fb0 models : fix LFM2 tensors (#17548) 2025-11-27 16:04:29 +02:00
matt23654 909072abcf cuda : fix UMA detection on discrete GPUs. (#17537) 2025-11-27 13:35:35 +02:00
89 changed files with 9183 additions and 5105 deletions
+8 -8
View File
@@ -45,7 +45,7 @@ sd=`dirname $0`
cd $sd/../
SRC=`pwd`
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON"
if [ ! -z ${GG_BUILD_METAL} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
@@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl {
qnt="$1"
@@ -523,8 +523,8 @@ function gg_run_embd_bge_small {
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
set +e
}
@@ -564,7 +564,7 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf"
# for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
# sample output
# rerank score 0: 0.029
+968
View File
@@ -13,6 +13,120 @@
using json = nlohmann::ordered_json;
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder,
const common_regex & prefix,
size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = { { "arguments" } };
if (auto res = builder.try_find_regex(prefix)) {
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call array");
}
} else {
builder.add_content(builder.consume_rest());
}
}
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
std::string arguments;
if (builder.is_partial()) {
arguments = (json{
{ "code", code + builder.healing_marker() }
})
.dump();
auto idx = arguments.find(builder.healing_marker());
if (idx != std::string::npos) {
arguments.resize(idx);
}
} else {
arguments = (json{
{ "code", code }
})
.dump();
}
return arguments;
}
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static void parse_json_tool_calls(
common_chat_msg_parser & builder,
const std::optional<common_regex> & block_open,
const std::optional<common_regex> & function_regex_start_only,
const std::optional<common_regex> & function_regex,
const common_regex & close_regex,
const std::optional<common_regex> & block_close,
bool allow_raw_python = false,
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name =
nullptr) {
auto parse_tool_calls = [&]() {
size_t from = std::string::npos;
auto first = true;
while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) :
function_regex ? builder.try_find_regex(*function_regex, from) :
std::nullopt;
if (res) {
std::string name;
if (get_function_name) {
name = get_function_name(*res);
} else {
GGML_ASSERT(res->groups.size() == 2);
name = builder.str(res->groups[1]);
}
first = false;
if (name.empty()) {
// get_function_name signalled us that we should skip this match and treat it as content.
from = res->groups[0].begin + 1;
continue;
}
from = std::string::npos;
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) {
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(close_regex);
}
continue;
}
if (maybe_raw_python) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
if (!builder.add_tool_call(name, "", arguments)) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
}
break;
}
if (block_close) {
builder.consume_regex(*block_close);
}
builder.consume_spaces();
builder.add_content(builder.consume_rest());
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
}
} else {
parse_tool_calls();
}
}
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
@@ -532,3 +646,857 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
void common_chat_msg_parser::clear_tools() {
result_.tool_calls.clear();
}
/**
* All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below
* to reduce incremental compile time for parser changes.
*/
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
static const std::vector<std::vector<std::string>> args_paths = {
{"tool_call", "arguments"},
{"tool_calls", "arguments"},
};
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
if (data.value.contains("tool_calls")) {
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool calls");
}
} else if (data.value.contains("tool_call")) {
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (data.value.contains("response")) {
const auto & response = data.value.at("response");
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
if (data.is_partial) {
throw common_chat_msg_partial_exception("incomplete response");
}
} else {
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
}
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("[THINK]", "[/THINK]");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
if (tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);
common_healing_marker healing_marker;
json args = json::object();
while (true) {
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
auto arg_name = builder.str(arg_res->groups[1]);
auto partial = builder.consume_json();
args[arg_name] = partial.json;
healing_marker.marker = partial.healing_marker.marker;
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
builder.consume_spaces();
if (!builder.try_consume_literal(",")) {
break;
}
} else {
break;
}
}
builder.consume_literal(")");
builder.consume_spaces();
auto arguments = args.dump();
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
return;
}
}
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ function_regex,
/* function_regex= */ std::nullopt,
close_regex,
std::nullopt);
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
static const common_regex function_regex("(?:<tool▁call▁begin>)?function<tool▁sep>([^\n]+)\n```json\n");
static const common_regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<tool▁sep>)");
static const common_regex close_regex("(?:[\\s]*)?<tool▁call▁end>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG_DBG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG_DBG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<minimax:tool_call>",
/* form.tool_start = */ "<invoke name=\"",
/* form.tool_sep = */ "\">",
/* form.key_start = */ "<parameter name=\"",
/* form.key_val_sep = */ "\">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</invoke>",
/* form.scope_end = */ "</minimax:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<|tool_calls_section_begin|>";
form.tool_start = "<|tool_call_begin|>";
form.tool_sep = "<|tool_call_argument_begin|>{";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}<|tool_call_end|>";
form.scope_end = "<|tool_calls_section_end|>";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_calls>[";
form.tool_start = "{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}, ";
form.scope_end = "]</tool_calls>";
form.raw_argval = false;
form.last_val_end = "";
form.last_tool_end = "}";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
}
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "";
form.tool_start = "<tool_call>\n{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}\n</tool_call>";
form.scope_end = "";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
static const common_regex start_regex("<\\|start\\|>assistant");
static const common_regex analysis_regex("<\\|channel\\|>analysis");
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
static const common_regex preamble_regex("<\\|channel\\|>commentary");
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
auto consume_end = [&](bool include_end = false) {
if (auto res = builder.try_find_literal("<|end|>")) {
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
}
return builder.consume_rest();
};
auto handle_tool_call = [&](const std::string & name) {
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
if (builder.syntax().parse_tool_calls) {
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
};
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
auto match = regex.search(input, 0, true);
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
return match;
}
return std::nullopt;
};
do {
auto header_start_pos = builder.pos();
auto content_start = builder.try_find_literal("<|message|>");
if (!content_start) {
throw common_chat_msg_partial_exception("incomplete header");
}
auto header = content_start->prelude;
if (auto match = regex_match(tool_call1_regex, header)) {
auto group = match->groups[1];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (auto match = regex_match(tool_call2_regex, header)) {
auto group = match->groups[2];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (regex_match(analysis_regex, header)) {
builder.move_to(header_start_pos);
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(consume_end(true));
} else {
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
}
continue;
}
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
builder.add_content(consume_end());
continue;
}
// Possibly a malformed message, attempt to recover by rolling
// back to pick up the next <|start|>
LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
builder.move_to(header_start_pos);
} while (builder.try_find_regex(start_regex, std::string::npos, false));
auto remaining = builder.consume_rest();
if (!remaining.empty()) {
LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
}
}
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "",
/* form.tool_start = */ "<tool_call>",
/* form.tool_sep = */ "",
/* form.key_start = */ "<arg_key>",
/* form.key_val_sep = */ "</arg_key>",
/* form.val_end = */ "</arg_value>",
/* form.tool_end = */ "</tool_call>",
/* form.scope_end = */ "",
/* form.key_val_sep2 = */ "<arg_value>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
static const common_regex close_regex(R"(\s*)");
parse_json_tool_calls(
builder,
std::nullopt,
function_regex_start_only,
function_regex,
close_regex,
std::nullopt,
/* allow_raw_python= */ true,
/* get_function_name= */ [&](const auto & res) -> std::string {
auto at_start = res.groups[0].begin == 0;
auto name = builder.str(res.groups[1]);
if (!name.empty() && name.back() == '{') {
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
builder.move_back(1);
}
auto idx = name.find_last_not_of("\n{");
name = name.substr(0, idx + 1);
if (at_start && name == "all") {
return "";
}
return name;
});
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
std::nullopt);
if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex open_regex(
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
const auto & open_tag = res->groups[2];
std::string close_tag;
if (!res->groups[3].empty()) {
builder.move_to(res->groups[3].begin);
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
}
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
static const common_regex start_think_regex(regex_escape("<think>"));
static const common_regex end_think_regex(regex_escape("</think>"));
// Granite models output partial tokens such as "<" and "<think".
// By leveraging try_consume_regex()/try_find_regex() throwing
// common_chat_msg_partial_exception for these partial tokens,
// processing is interrupted and the tokens are not passed to add_content().
if (auto res = builder.try_consume_regex(start_think_regex)) {
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
builder.try_find_regex(end_think_regex, std::string::npos, false);
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
}
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags
static const common_regex start_response_regex(regex_escape("<response>"));
static const common_regex end_response_regex(regex_escape("</response>"));
// Granite models output partial tokens such as "<" and "<response".
// Same hack as reasoning parsing.
if (builder.try_consume_regex(start_response_regex)) {
builder.try_find_regex(end_response_regex);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
builder.consume_spaces();
if (!builder.try_consume_literal("<|tools_suffix|>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
for (const auto & value : tool_calls_data.json) {
if (value.is_object()) {
builder.add_tool_call_short_form(value);
}
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
// Loop through all tool calls
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
builder.move_to(res->groups[0].end);
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
auto tool_calls_data = builder.consume_json();
// Consume end marker
builder.consume_spaces();
if (!builder.try_consume_regex(tool_call_end_regex)) {
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
}
// Process each tool call in the array
if (tool_calls_data.json.is_array()) {
for (const auto & tool_call : tool_calls_data.json) {
if (!tool_call.is_object()) {
throw common_chat_msg_partial_exception("Tool call must be an object");
}
if (!tool_call.contains("name")) {
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
}
std::string function_name = tool_call.at("name");
std::string arguments = "{}";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else if (tool_call.at("arguments").is_string()) {
arguments = tool_call.at("arguments");
}
}
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
} else {
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
}
// Consume any trailing whitespace after this tool call
builder.consume_spaces();
}
// Consume any remaining content after all tool calls
auto remaining = builder.consume_rest();
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<seed:tool_call>",
/* form.tool_start = */ "<function=",
/* form.tool_sep = */ ">",
/* form.key_start = */ "<parameter=",
/* form.key_val_sep = */ ">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</function>",
/* form.scope_end = */ "</seed:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest());
}
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
case COMMON_CHAT_FORMAT_GENERIC:
common_chat_parse_generic(builder);
break;
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
case COMMON_CHAT_FORMAT_MAGISTRAL:
common_chat_parse_magistral(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
common_chat_parse_functionary_v3_1_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
common_chat_parse_hermes_2_pro(builder);
break;
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
common_chat_parse_firefunction_v2(builder);
break;
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
case COMMON_CHAT_FORMAT_APERTUS:
common_chat_parse_apertus(builder);
break;
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
common_chat_parse_lfm2(builder);
break;
case COMMON_CHAT_FORMAT_MINIMAX_M2:
common_chat_parse_minimax_m2(builder);
break;
case COMMON_CHAT_FORMAT_GLM_4_5:
common_chat_parse_glm_4_5(builder);
break;
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
common_chat_parse_xiaomi_mimo(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {
builder.clear_tools();
builder.move_to(0);
common_chat_parse_content_only(builder);
}
}
auto msg = builder.result();
if (!is_partial) {
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
}
return msg;
}
-952
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
}
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
+30
View File
@@ -4183,6 +4183,36 @@ class Qwen3MoeModel(Qwen2MoeModel):
super().set_vocab()
@ModelBase.register("Qwen3NextForCausalLM")
class Qwen3NextModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp"):
return [] # ignore MTP layers for now
if name.endswith(".A_log"):
data_torch = -torch.exp(data_torch)
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif "conv1d" in name:
data_torch = data_torch.squeeze()
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
data_torch = data_torch + 1
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("RND1")
class RND1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.RND1
+13
View File
@@ -42,6 +42,9 @@ The following releases are verified and recommended:
## News
- 2025.11
- Support malloc memory on device more than 4GB.
- 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|GPU|Base tokens/s|Increased tokens/s|Percent|
@@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
## Known Issues
@@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device`
You need to enable to support 4GB memory malloc by:
```
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
```
### **GitHub contribution**:
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
+4 -3
View File
@@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
params.embedding = true;
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
// in order to support any number of prompts
if (params.n_parallel == 1) {
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
params.kv_unified = true;
params.n_parallel = n_seq_max;
}
// utilize the full context
@@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
params.n_ubatch = params.n_batch;
}
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
llama_backend_init();
llama_numa_init(params.numa);
+2 -2
View File
@@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
@@ -4,6 +4,11 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
if [ -z "$MODEL_TESTING_PROMPT"]; then
MODEL_TESTING_PROMPT="Hello, my name is"
fi
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
fi
echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT
cmake --build ../../build --target llama-logits -j8
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
@@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
# of using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
device = next(model.parameters()).device
if os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
+3
View File
@@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=99
CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU"
+6 -3
View File
@@ -6,7 +6,7 @@
# If you want more control, DPC++ Allows selecting a specific device through the
# following environment variable
#export ONEAPI_DEVICE_SELECTOR="level_zero:0"
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
@@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "Using $GGML_SYCL_DEVICE as the main GPU"
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT}
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi
+2
View File
@@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0
+3 -1
View File
@@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99
+1
View File
@@ -183,6 +183,7 @@ endif()
# ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON)
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
# 3rd party libs / backends
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
+1 -1
View File
@@ -8,7 +8,7 @@ extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_MINOR_VERSION 5
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16
+4
View File
@@ -221,6 +221,10 @@ if (GGML_BACKEND_DL)
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
endif()
if (GGML_SCHED_NO_REALLOC)
target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
endif()
add_library(ggml
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)
+8 -3
View File
@@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
}
if (realloc) {
#ifndef NDEBUG
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
{
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
if (cur_size > 0) {
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
__func__, ggml_backend_buft_name(galloc->bufts[i]),
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
}
}
#endif
ggml_vbuffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
if (galloc->buffers[i] == NULL) {
+9 -3
View File
@@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
// allocate graph
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
#ifdef GGML_SCHED_NO_REALLOC
GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__);
#endif
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
// the re-allocation may cause the split inputs to be moved to a different address
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]);
}
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+11 -8
View File
@@ -1,20 +1,23 @@
#include "ggml-backend-impl.h"
#if defined(__riscv) && __riscv_xlen == 64
#include <sys/auxv.h>
//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24
#ifndef COMPAT_HWCAP_ISA_V
#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A'))
#endif
#include <asm/hwprobe.h>
#include <asm/unistd.h>
#include <unistd.h>
struct riscv64_features {
bool has_rvv = false;
riscv64_features() {
uint32_t hwcap = getauxval(AT_HWCAP);
struct riscv_hwprobe probe;
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
probe.value = 0;
has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V);
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
if (0 == ret) {
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
}
}
};
+2 -1
View File
@@ -9766,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
}
const float diag = A_batch[i00 * n + i00];
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
}
}
+1 -1
View File
@@ -44,7 +44,7 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
size_t temp_storage_bytes = 0;
+18 -10
View File
@@ -84,12 +84,12 @@
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
# define GGML_CUDA_USE_CUB
@@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) {
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
@@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) {
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fast_fp16_available(const int cc) {
return GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) {
}
static bool bf16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fp32_mma_hardware_available(const int cc) {
@@ -558,8 +562,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
acc += v.y*u.y;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
#define V_DOT2_F32_F16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#ifdef V_DOT2_F32_F16_AVAILABLE
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
@@ -571,7 +579,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
#endif // V_DOT2_F32_F16_AVAILABLE
}
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
+4 -1
View File
@@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
}
}
}
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
nb12, nb13);
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
@@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda(
ne00n = ne00;
ne01n = ne01;
ne02n = ne02;
} else if (nb00 > nb02) {
} else {
ne00n = ne00;
ne01n = ne01*ne02;
ne02n = 1;
+2 -2
View File
@@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#else
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
+1 -1
View File
@@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val;
+18 -18
View File
@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
constexpr int ne_KQ = ncols*D;
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#else
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
float KQ_max[ncols];
float KQ_sum[ncols];
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
}
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
if constexpr (Q_q8_1) {
@@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
tmp_q_i32[i] = 0;
}
}
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
__syncthreads();
} else {
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
Q_reg[j][k].y *= scale;
}
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
@@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
KQ_reg[j] = sum;
}
}
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
KQ[j*nthreads + tid] = KQ_reg[j];
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
#ifndef GGML_USE_HIP
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
}
}
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
KQ_max[j_VKQ] = kqmax_new;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
KQ_sum[j_VKQ] *= kqmax_scale;
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
+22 -12
View File
@@ -53,6 +53,7 @@
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml.h"
#include <algorithm>
@@ -2717,6 +2718,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_OPT_STEP_SGD:
ggml_cuda_opt_step_sgd(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
default:
return false;
}
@@ -3046,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() &&
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
@@ -3056,8 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@@ -3065,7 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
@@ -3081,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
@@ -3095,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
@@ -3107,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * rope = cgraph->nodes[node_idx];
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
@@ -3837,7 +3845,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
// Check if UMA is explicitly enabled via environment variable
bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
bool is_uma = prop.unifiedAddressing > 0 || uma_env;
bool is_uma = prop.integrated > 0 || uma_env;
if (is_uma) {
// For UMA systems (like DGX Spark), use system memory info
@@ -4255,6 +4263,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
default:
return false;
}
+2 -2
View File
@@ -889,8 +889,8 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+1 -1
View File
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (src1_ncols > 16) {
return false;
}
}
+203
View File
@@ -0,0 +1,203 @@
#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"
#define MAX_N_FAST 64
#define MAX_K_FAST 32
// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When ncols_template == 0 the bounds for the loops in this function are not
// known and can't be unrolled. As we want to keep pragma unroll for all other
// cases we supress the clang transformation warning here.
#ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n_arg,
const int k_arg) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int col_idx = threadIdx.y;
if (col_idx >= k) {
return;
}
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
}
}
__syncthreads();
#pragma unroll
for (int row = 0; row < n; ++row) {
float sum = 0.0f;
{
int j = lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
if (row >= WARP_SIZE) {
int j = WARP_SIZE + lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
sum = warp_reduce_sum(sum);
if (lane == 0) {
const float b_val = sXt[col_idx * n + row];
const float a_diag = sA[row * n + row];
// no safeguards for division by zero because that indicates corrupt
// data anyway
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
}
}
}
#ifdef __clang__
# pragma clang diagnostic pop
#endif // __clang__
static void solve_tri_f32_cuda(const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb12,
size_t nb13,
size_t nb2,
size_t nb3,
cudaStream_t stream) {
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
dim3 threads(WARP_SIZE, k);
dim3 grid(ne02 * ne03);
if (n == 64) {
switch (k) {
case 32:
solve_tri_f32_fast<64, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 16:
solve_tri_f32_fast<64, 16>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 14:
solve_tri_f32_fast<64, 14>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 12:
solve_tri_f32_fast<64, 12>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 10:
solve_tri_f32_fast<64, 10>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 8:
solve_tri_f32_fast<64, 8>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 6:
solve_tri_f32_fast<64, 6>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 4:
solve_tri_f32_fast<64, 4>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 2:
solve_tri_f32_fast<64, 2>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 1:
solve_tri_f32_fast<64, 1>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
default:
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else { // run general case
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
ggml_is_contiguous(src0);
ggml_is_contiguous(src1);
const int64_t n = src0->ne[0];
const int64_t k = src1->ne[0];
GGML_ASSERT(n <= 64);
GGML_ASSERT(k <= 32);
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
}
+3
View File
@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+91 -20
View File
@@ -106,6 +106,7 @@ enum rpc_cmd {
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_GRAPH_RECOMPUTE,
RPC_CMD_COUNT,
};
@@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
uint8_t result;
};
struct rpc_msg_graph_compute_rsp {
uint8_t result;
};
struct rpc_msg_get_device_memory_req {
uint32_t device;
};
@@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
};
struct rpc_msg_graph_recompute_req {
uint32_t device;
};
#pragma pack(pop)
// RPC data structures
@@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
size_t max_size;
};
struct graph_cache {
bool is_cached(const ggml_cgraph * cgraph) {
if ((int)last_graph.size() != cgraph->n_nodes) {
return false;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
return false;
}
}
return true;
}
void add(const ggml_cgraph * cgraph) {
last_graph.resize(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; i++) {
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
}
}
std::vector<ggml_tensor> last_graph;
};
struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
graph_cache gc;
};
struct ggml_backend_rpc_buffer_context {
@@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return (enum ggml_status)response.result;
GGML_ASSERT(cgraph->n_nodes > 0);
bool reuse = rpc_ctx->gc.is_cached(cgraph);
if (reuse) {
rpc_msg_graph_recompute_req request;
request.device = rpc_ctx->device;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
RPC_STATUS_ASSERT(status);
} else {
rpc_ctx->gc.add(cgraph);
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
RPC_STATUS_ASSERT(status);
}
return GGML_STATUS_SUCCESS;
}
static ggml_backend_i ggml_backend_rpc_interface = {
@@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name
/* .name = */ dev_name,
/* .gc = */ {},
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
@@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
class rpc_server {
public:
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
: backends(std::move(backends)), cache_dir(cache_dir) {
rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
: backends(std::move(all_backends)), cache_dir(cache_dir) {
stored_graphs.resize(backends.size());
}
~rpc_server();
@@ -936,11 +976,17 @@ public:
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input);
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
struct stored_graph {
ggml_context_ptr ctx_ptr;
ggml_cgraph * graph;
};
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -953,6 +999,8 @@ private:
std::vector<ggml_backend_t> backends;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
// store the last computed graph for each backend
std::vector<stored_graph> stored_graphs;
};
void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
return result;
}
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
@@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
}
}
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status;
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
stored_graphs[device].graph = graph;
return true;
}
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
uint32_t device = request.device;
if (device >= backends.size()) {
return false;
}
if (stored_graphs[device].graph == nullptr) {
return false;
}
ggml_cgraph * graph = stored_graphs[device].graph;
LOG_DBG("[%s] device: %u\n", __func__, device);
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
return true;
}
@@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, input)) {
return;
}
rpc_msg_graph_compute_rsp response;
if (!server.graph_compute(input, response)) {
if (!server.graph_compute(input)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
break;
}
case RPC_CMD_GRAPH_RECOMPUTE: {
rpc_msg_graph_recompute_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
if (!server.graph_recompute(request)) {
return;
}
break;
+6 -2
View File
@@ -91,7 +91,10 @@ if (GGML_SYCL_F16)
add_compile_definitions(GGML_SYCL_F16)
endif()
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
if (GGML_SYCL_TARGET STREQUAL "INTEL")
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required)
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# INFO: Allowed Sub_group_sizes are not consistent through all
@@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
else()
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
# default for other target
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
endif()
if (GGML_SYCL_GRAPH)
+26
View File
@@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias,
return dpct::pow(base, exph);
}
static const sycl::uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
return sycl::uint3(mp, L, d);
}
static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
return (hi + n) >> fastdiv_values.y();
}
static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t div_val = fastdiv(n, fastdiv_values);
const uint32_t mod_val = n - div_val * fastdiv_values.z();
return sycl::uint2(div_val, mod_val);
}
#endif // GGML_SYCL_COMMON_HPP
-3
View File
@@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
GGML_TENSOR_BINARY_OP_LOCALS01;
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+79 -51
View File
@@ -1,72 +1,100 @@
#include "pad_reflect_1d.hpp"
void pad_reflect_1d_f32(const float* src,float* dst,
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const sycl::nd_item<3> &item_ct1){
static void pad_reflect_1d_kernel_f32(
const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
const int64_t ne03, const int64_t nb00, const int64_t nb01,
const int64_t nb02, const int64_t nb03, const int64_t nb0,
const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
const int p1, sycl::nd_item<3> item_ct1) {
const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
const int i1 = item_ct1.get_group(1);
const int g2 = item_ct1.get_group(2);
const int i2 = g2 % ne02;
const int i3 = g2 / ne02;
const int64_t i3 = item_ct1.get_group(0);
const int64_t i2 = item_ct1.get_group(1);
if (i0 >= p0 + ne0 + p1) return;
const sycl::uint2 div_mod_packed =
fast_div_modulo(item_ct1.get_group(2), ne01);
const int64_t tile1 = div_mod_packed.y();
const int64_t tile0 = div_mod_packed.x();
const int64_t i1 = tile1;
const int64_t i0 =
item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
int t = i0 - p0;
int period = 2 * ne0 -2;
int m = t % period;
m += (m < 0) * period;
int center = ne0 -1;
int srci0 = center - abs(center - m);
if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
return;
}
int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
dst[offest_dst] = src[offest_src];
const char *src0_ptr =
(const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
int64_t src_idx;
if (rel_i0 < 0) {
// Left padding - reflect
src_idx = -rel_i0;
} else if (rel_i0 < ne00) {
// Middle - copy
src_idx = rel_i0;
} else {
// Right padding - reflect
src_idx = 2 * ne00 - 2 - rel_i0;
}
const float value = *(const float *)(src0_ptr + src_idx * nb00);
*(float *)(dst_ptr + i0 * nb0) = value;
GGML_UNUSED(p1);
}
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
ggml_tensor *dst) {
const ggml_tensor * src0 = dst->src[0];
queue_ptr stream = ctx.stream();
const ggml_tensor *src0 = dst->src[0];
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *) dst->op_params;
const int32_t *opts = (const int32_t *)dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];
const int64_t ne0 = src0->ne[0];
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne00 = dst->ne[0];
const int64_t ne01 = dst->ne[1];
const int64_t ne02 = dst->ne[2];
const int64_t ne03 = dst->ne[3];
const int64_t ne0 = dst->ne[0];
const int64_t nb00 = dst->nb[0];
const int64_t nb01 = dst->nb[1];
const int64_t nb02 = dst->nb[2];
const int64_t nb03 = dst->nb[3];
const int64_t nb0 = src0->nb[0];
const int64_t nb1 = src0->nb[1];
const int64_t nb2 = src0->nb[2];
const int64_t nb3 = src0->nb[3];
GGML_ASSERT(ne0 == ne00 + p0 + p1);
int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
const int64_t tiles0 = (ne0 + bx - 1) / bx;
const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
(unsigned)ne03);
const dpct::dim3 block_dims((unsigned)bx, 1, 1);
stream->parallel_for(
sycl::nd_range<3>(global,
local),
[=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
(const float *) src0->data, (float *) dst->data,
ne0, ne02, p0, p1,
nb0, nb1, nb2, nb3,
nb00, nb01, nb02, nb03
, item_ct1);
});
stream->submit([&](sycl::handler &cgh) {
auto src0_data_ct0 = src0->data;
auto dst_data_ct1 = dst->data;
auto src0_nb_ct7 = src0->nb[0];
auto src0_nb_ct8 = src0->nb[1];
auto src0_nb_ct9 = src0->nb[2];
auto src0_nb_ct10 = src0->nb[3];
auto dst_nb_ct11 = dst->nb[0];
auto dst_nb_ct12 = dst->nb[1];
auto dst_nb_ct13 = dst->nb[2];
auto dst_nb_ct14 = dst->nb[3];
cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
pad_reflect_1d_kernel_f32(
src0_data_ct0, dst_data_ct1, ne0, ne00,
ne01_packed, ne02, ne03, src0_nb_ct7,
src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
dst_nb_ct14, p0, p1, item_ct1);
});
});
}
+2
View File
@@ -3,6 +3,8 @@
#include "common.hpp"
#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
#endif // GGML_SYCL_PAD_REFLECT_1D_HPP
+313 -56
View File
@@ -399,6 +399,18 @@ struct vk_conv2d_pipeline_state {
}
};
struct vk_solve_tri_pipeline_state {
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
: N(N), K(K) {}
uint32_t N, K;
bool operator<(const vk_solve_tri_pipeline_state &b) const {
return std::tie(N, K) <
std::tie(b.N, b.K);
}
};
enum shader_reduction_mode {
SHADER_REDUCTION_MODE_SHMEM,
SHADER_REDUCTION_MODE_HYBRID,
@@ -601,9 +613,10 @@ struct vk_device_struct {
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@@ -637,6 +650,7 @@ struct vk_device_struct {
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
@@ -711,6 +725,7 @@ struct vk_device_struct {
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
@@ -1597,7 +1612,7 @@ class vk_perf_logger {
}
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->src[0]->ne[1];
const uint64_t n = node->ne[1];
const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
const uint64_t k = node->src[1]->ne[0];
const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
std::string name = ggml_op_name(node->op);
@@ -3511,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
// the number of rows computed per shader depends on GPU model and quant
uint32_t rm_stdq = 1;
uint32_t rm_kq = 2;
uint32_t rm_stdq_int = 1;
uint32_t rm_kq_int = 1;
if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (device->architecture == AMD_GCN) {
rm_stdq = 2;
rm_kq = 4;
rm_stdq_int = 4;
}
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
} else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
rm_stdq = 2;
rm_stdq_int = 2;
}
uint32_t rm_iq = 2 * rm_kq;
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
@@ -3598,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
}
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
}
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
GGML_UNUSED(rm_stdq_int);
GGML_UNUSED(rm_kq_int);
#endif
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -3863,6 +3917,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -4002,6 +4059,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
for (auto &s : device->pipeline_solve_tri_f32) {
const vk_solve_tri_pipeline_state &state = s.first;
ggml_vk_create_pipeline(
device, s.second, "solve_tri_f32",
solve_tri_f32_len, solve_tri_f32_data, "main", 3,
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
}
#define IM2COL(bda) \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
@@ -5428,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break;
default:
return nullptr;
@@ -5567,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
GGML_ASSERT(b_type == GGML_TYPE_F32);
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);
if (b_type == GGML_TYPE_Q8_1) {
switch (a_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break;
default:
return nullptr;
}
}
switch (a_type) {
case GGML_TYPE_F32:
@@ -5600,7 +5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
return nullptr;
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
// heuristic to choose workgroup size
uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
// Prefer larger workgroups when M is small, to spread the work out more
// and keep more SMs busy.
// q6_k seems to prefer small workgroup size even for "medium" values of M.
if (a_type == GGML_TYPE_Q6_K) {
if (m < 4096 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
}
} else {
if (m <= 8192 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
}
}
}
if (b_type == GGML_TYPE_Q8_1) {
if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
}
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
@@ -6792,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
return false;
}
// General performance issue with q3_k and q6_k due to 2-byte alignment
if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
return false;
}
// MMVQ is generally good for batches
if (n > 1) {
return true;
}
// Quantization overhead is not worth it for small k
switch (device->vendor_id) {
case VK_VENDOR_ID_NVIDIA:
if (k <= 4096) {
return false;
}
switch (src0_type) {
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q8_0:
return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
default:
return true;
}
case VK_VENDOR_ID_AMD:
if (k < 2048) {
return false;
}
switch (src0_type) {
case GGML_TYPE_Q8_0:
return device->architecture == vk_device_architecture::AMD_GCN;
@@ -6813,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
return true;
}
case VK_VENDOR_ID_INTEL:
if (k < 2048) {
return false;
}
switch (src0_type) {
// From tests on A770 Linux, may need more tuning
case GGML_TYPE_Q4_0:
@@ -6826,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
}
GGML_UNUSED(m);
GGML_UNUSED(k);
}
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -7549,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (x_non_contig || qx_needs_dequant) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig) {
if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
}
@@ -7575,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t ne10 = src1->ne[0];
const uint64_t ne11 = src1->ne[1];
// const uint64_t ne12 = src1->ne[2];
const uint64_t ne12 = src1->ne[2];
// const uint64_t ne13 = src1->ne[3];
const uint64_t nei0 = ids->ne[0];
@@ -7592,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
@@ -7616,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
// Check for mmq first
vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
vk_pipeline to_q8_1 = nullptr;
if (dmmv == nullptr) {
// Fall back to f16 dequant mul mat
dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
quantize_y = false;
}
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(dmmv != nullptr);
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
{
if (
(qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
@@ -7631,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) {
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
ctx->prealloc_size_y = y_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
@@ -7643,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}
@@ -7658,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
} else {
d_X = d_Qx;
}
if (qy_needs_dequant) {
if (qy_needs_dequant || quantize_y) {
d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
} else {
d_Y = d_Qy;
@@ -7686,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ctx->prealloc_y_last_tensor_used = src1;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
}
uint32_t stride_batch_y = ne10*ne11;
@@ -7747,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig) {
if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
}
@@ -8269,6 +8430,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_TRI:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
@@ -8496,6 +8663,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cumsum_f32;
}
return nullptr;
case GGML_OP_SOLVE_TRI:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
if (it != ctx->device->pipeline_solve_tri_f32.end()) {
pipeline = it->second;
} else {
ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
return pipeline;
}
return nullptr;
case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32;
@@ -8832,6 +9019,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { nr, 1, 1 };
}
} break;
case GGML_OP_SOLVE_TRI:
{
uint32_t nr = (uint32_t)(ne02 * ne03);
if (nr > 262144) {
elements = { 512, 512, CEIL_DIV(nr, 262144) };
} else if (nr > 512) {
elements = { 512, CEIL_DIV(nr, 512), 1 };
} else {
elements = { nr, 1, 1 };
}
}
break;
case GGML_OP_RMS_NORM:
if (ctx->do_add_rms_partials) {
// Run one element per thread, 128 threads per workgroup
@@ -8938,6 +9137,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_TRI:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@@ -9618,6 +9818,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
}
static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -10168,7 +10375,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup
uint32_t max_pipeline = num_topk_pipelines - 3;
uint32_t max_pipeline = num_topk_pipelines - 1;
uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
max_pipeline = std::min(preferred_pipeline, max_pipeline);
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
@@ -10260,6 +10469,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
}
static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int32_t s0 = dst->op_params[0];
const int32_t s1 = dst->op_params[1];
@@ -11726,6 +11950,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_LOG:
ggml_vk_log(ctx, compute_ctx, src0, node);
break;
case GGML_OP_TRI:
ggml_vk_tri(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -11871,6 +12099,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_COUNT_EQUAL:
ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_SOLVE_TRI:
ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_IM2COL:
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
@@ -13847,7 +14079,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_OPT_STEP_SGD:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_TRI:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_ARGSORT:
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
@@ -13916,6 +14150,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return false;
}
case GGML_OP_SOLVE_TRI:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
return false;
}
const uint32_t N = op->src[0]->ne[0];
const uint32_t K = op->src[1]->ne[0];
// K dimension limited to workgroup size
if (K > 128) {
return false;
}
if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
return true;
}
case GGML_OP_ARGMAX:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_COUNT_EQUAL:
@@ -14419,6 +14672,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_LOG) {
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_TRI) {
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_CLAMP) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
@@ -14588,6 +14843,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {
tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_SOLVE_TRI) {
tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
} else if (tensor->op == GGML_OP_IM2COL) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
@@ -4,13 +4,6 @@
#include "types.glsl"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -156,7 +156,7 @@ void main() {
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
@@ -22,6 +22,13 @@ layout (push_constant) uniform parameter
#if !RMS_NORM_ROPE_FUSION
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#endif
@@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
uint get_idx() {
@@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@@ -13,8 +13,6 @@
#include "mul_mat_vec_iface.glsl"
#include "dequant_funcs.glsl"
layout (push_constant) uniform parameter
{
uint ncols;
@@ -5,13 +5,15 @@
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#endif
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
@@ -10,60 +10,56 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
#define K_PER_ITER 8
#include "mul_mmq_funcs.glsl"
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#else
#error unimplemented
#endif
uint a_offset, b_offset, d_offset;
int32_t cache_b_qs[2];
int32_t cache_b_qs[K_PER_ITER / 4];
vec2 cache_b_ds;
#include "mul_mat_vecq_funcs.glsl"
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4;
const uint b_qs_idx = tid % (32 / K_PER_ITER);
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
#if QUANT_R == 2
// Assumes K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
#if K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#elif K_PER_ITER == 16
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#else
#error unimplemented
#endif
#endif
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
ibi += p.ncols;
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
}
}
}
@@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
a_offset /= QUANT_K_Q8_1;
b_offset /= QUANT_K_Q8_1;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
@@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
@@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
@@ -0,0 +1,379 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#include "types.glsl"
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
// 2-byte loads for Q4_0 blocks (18 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q5_0)
// 2-byte loads for Q5_0 blocks (22 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * da * dsb.x);
}
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
}
#endif
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(ib_a, iqs);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(ib_a, iqs * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(ib_a, iqs * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
// 2 quants per call => divide sums by 8/2 = 4
return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
}
#endif
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t sum_d = 0;
int32_t sum_m = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const uint8_t scale = get_scale(ib_a, iqs * 4);
const vec2 dm = vec2(get_dm(ib_a));
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);
sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
}
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;
// bitwise OR to add 4 if hmask is set, subtract later
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
}
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 4;
const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));
return float(data_a[ib_k].d) * float(scale - 32);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;
return i32vec4(vals0, vals1, vals2, vals3);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs;
const uint qh_shift = iqs_k / 8;
return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));
#endif
}
vec2 get_dm_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
}
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
}
#endif
@@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#define MMQ_SHMEM
#include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID
@@ -9,31 +9,6 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}
#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
#ifdef DATA_A_Q4_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
#else // DATA_A_Q4_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
}
#endif // MMQ_SHMEM
#endif
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
#ifdef DATA_A_Q5_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
#else // DATA_A_Q5_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
@@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, qs_b);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
@@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));
}
#endif // MMQ_SHMEM
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ;
@@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
return ACC_TYPE(float(cache_b.ds.x) * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif
if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
@@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
if (is_in_bounds) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
@@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
return ACC_TYPE(float(cache_b.ds.x) * result);
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
if (is_in_bounds) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#endif
@@ -0,0 +1,72 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout (constant_id = 1) const uint N = 64;
layout (constant_id = 2) const uint K = 32;
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
uint a_base, b_base, x_base;
FLOAT_TYPE get_a(uint r, uint c) {
return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]);
}
FLOAT_TYPE get_b(uint r, uint c) {
return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]);
}
void store_x(uint r, uint c, FLOAT_TYPE v) {
data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
}
shared FLOAT_TYPE shA[N * N];
shared FLOAT_TYPE shB[N * K];
void main() {
const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
if (batch >= p.ne02 * p.ne03) {
return;
}
const uint i3 = batch / p.ne22;
const uint i2 = batch % p.ne22;
a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03;
b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
// Load the A matrix into shA
[[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
shA[idx] = get_a(idx / N, idx % N);
}
}
// Load the B matrix into shB
[[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
shB[idx] = get_b(idx / K, idx % K);
}
}
barrier();
FLOAT_TYPE X[N];
// Each thread solves one column
if (tid < K) {
[[unroll]] for (int r = 0; r < N; ++r) {
FLOAT_TYPE b = shB[r * K + tid];
// Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
[[unroll]] for (int c = 0; c < r; ++c) {
b -= shA[r * N + c] * X[c];
}
FLOAT_TYPE x = b / shA[r * N + r];
X[r] = x;
store_x(r, tid, x);
}
}
}
@@ -0,0 +1,43 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
#define GGML_TRI_TYPE_UPPER_DIAG 0
#define GGML_TRI_TYPE_UPPER 1
#define GGML_TRI_TYPE_LOWER_DIAG 2
#define GGML_TRI_TYPE_LOWER 3
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
int param = floatBitsToInt(p.param1);
bool pass = false;
switch (param) {
case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break;
case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break;
}
if (pass) {
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
} else {
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
}
}
@@ -679,14 +679,20 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname)) {
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif
@@ -846,6 +852,9 @@ void process_shaders() {
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -944,6 +953,8 @@ void process_shaders() {
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
for (auto transpose : {false, true}) {
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
@@ -1095,7 +1106,7 @@ void write_output_files() {
for (const std::string& btype : btypes) {
for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname)) {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
continue;
}
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
@@ -1104,6 +1115,16 @@ void write_output_files() {
src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
if (btype == "f16") {
continue;
}
hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";
hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
if (basename(input_filepath) == "mul_mat_vec.comp") {
src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
}
}
+33
View File
@@ -366,6 +366,7 @@ class MODEL_ARCH(IntEnum):
QWEN2VL = auto()
QWEN3 = auto()
QWEN3MOE = auto()
QWEN3NEXT = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
PHI2 = auto()
@@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto()
SSM_BETA_ALPHA = auto() # qwen3next
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
@@ -736,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.PHI2: "phi2",
@@ -900,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -1569,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN3NEXT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_BETA_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.QWEN3VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
+12 -6
View File
@@ -371,10 +371,13 @@ class GGUFWriter:
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
) -> None:
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# if tensor endianness is not passed, assume it's native to system
if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
if self.use_temp_file and self.temp_file is None:
@@ -397,13 +400,16 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# if tensor endianness is not passed, assume it's native to system
if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
+1 -1
View File
@@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow):
# Add tensors (including data)
for tensor in self.reader.tensors:
writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type)
writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess)
# Write header and metadata
writer.open_output_file(Path(file_path))
+1 -1
View File
@@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
writer.write_ti_data_to_file()
for tensor in reader.tensors:
writer.write_tensor_data(tensor.data)
writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
bar.update(tensor.n_bytes)
writer.close()
+16 -6
View File
@@ -672,10 +672,11 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
"model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next
),
MODEL_TENSOR.SSM_CONV1D: (
@@ -683,6 +684,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.conv1d", # plamo2
"model.layers.{bid}.linear_attn.conv1d", # qwen3next
),
MODEL_TENSOR.SSM_X: (
@@ -697,6 +699,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
),
MODEL_TENSOR.SSM_DT_NORM: (
@@ -709,6 +712,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.A_log", # plamo2
"model.layers.{bid}.linear_attn.A_log", # qwen3next
),
MODEL_TENSOR.SSM_B_NORM: (
@@ -731,17 +735,23 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
"backbone.layers.{bid}.mixer.norm", # mamba2
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
"model.layers.{bid}.linear_attn.norm", # qwen3next
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.linear_attn.out_proj", # qwen3next
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
),
MODEL_TENSOR.SSM_BETA_ALPHA: (
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
),
MODEL_TENSOR.TIME_MIX_W0: (
"model.layers.{bid}.attention.w0", # rwkv7
),
+1
View File
@@ -114,6 +114,7 @@ add_library(llama
models/qwen3vl.cpp
models/qwen3vl-moe.cpp
models/qwen3moe.cpp
models/qwen3next.cpp
models/refact.cpp
models/rnd1.cpp
models/rwkv6-base.cpp
+48 -3
View File
@@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
@@ -829,6 +830,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_QWEN3NEXT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_QWEN3VL,
{
@@ -2237,7 +2270,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
{ LLM_TENSOR_OUTPUT, "output" },
}
},
@@ -2259,7 +2292,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
@@ -2487,11 +2520,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
},
};
// declare information about the model weight tensors:
// - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight
// - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator
//
// for example, input layers are usually assigned to CPU/host buffer types
//
// a mismatch between the declared information and the actual layer/op in which the tensor is used can lead to sub-optimal
// assignment of the buffer types and extra overhead during computation
// example: https://github.com/ggml-org/llama.cpp/pull/17548
//
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
@@ -2546,6 +2589,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -2744,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
case LLM_ARCH_QWEN3NEXT:
return true;
default:
return false;
+2
View File
@@ -36,6 +36,7 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
@@ -381,6 +382,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2,
+6 -2
View File
@@ -1,5 +1,6 @@
#include "llama-context.h"
#include "llama-arch.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
@@ -299,7 +300,7 @@ llama_context::llama_context(
cross.v_embd.clear();
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// avoid reserving graphs with zero outputs - assume one output per sequence
@@ -542,7 +543,7 @@ bool llama_context::memory_update(bool optimize) {
throw std::runtime_error("failed to initialize memory context");
}
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -1386,6 +1387,9 @@ void llama_context::output_reorder() {
//
uint32_t llama_context::graph_max_nodes() const {
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
}
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}
+1 -1
View File
@@ -6,7 +6,7 @@
// bump if necessary
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
+102 -5
View File
@@ -2,7 +2,6 @@
#include "llama-impl.h"
#include "llama-mmap.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model-loader.h"
@@ -2225,6 +2224,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3NEXT:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
// Load linear attention (gated delta net) parameters
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Mark recurrent layers (linear attention layers)
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval"
}
switch (hparams.n_layer) {
case 80: type = LLM_TYPE_80B_A3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -6133,9 +6155,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
@@ -6414,6 +6437,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_QWEN3NEXT:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
}
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
// Calculate dimensions from hyperparameters
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t head_v_dim = hparams.ssm_d_state;
const int64_t n_k_heads = hparams.ssm_n_group;
const int64_t n_v_heads = hparams.ssm_dt_rank;
const int64_t key_dim = head_k_dim * n_k_heads;
const int64_t value_dim = head_v_dim * n_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
// Calculate projection sizes
const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
const int64_t ba_dim = n_v_heads * 2;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
if (!hparams.is_recurrent(i)) {
// Attention layers
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
// Q/K normalization for attention layers
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
} else {
// Linear attention (gated delta net) specific tensors
// Create tensors with calculated dimensions
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
}
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
// Shared experts
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -6684,6 +6775,7 @@ void llama_model::print_info() const {
arch == LLM_ARCH_FALCON_H1 ||
arch == LLM_ARCH_PLAMO2 ||
arch == LLM_ARCH_GRANITE_HYBRID ||
arch == LLM_ARCH_QWEN3NEXT ||
arch == LLM_ARCH_NEMOTRON_H) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
@@ -7425,7 +7517,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
case LLM_ARCH_PANGU_EMBED:
{
llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
}break;
} break;
case LLM_ARCH_QWEN3NEXT:
{
llm = std::make_unique<llm_build_qwen3next>(*this, params);
} break;
default:
GGML_ABORT("fatal error");
}
@@ -7652,6 +7748,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_COGVLM:
case LLM_ARCH_PANGU_EMBED:
case LLM_ARCH_AFMOE:
case LLM_ARCH_QWEN3NEXT:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
+4
View File
@@ -113,6 +113,7 @@ enum llm_type {
LLM_TYPE_16B_A1B,
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
LLM_TYPE_80B_A3B, // Qwen3 Next
LLM_TYPE_100B_A6B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_230B_A10B, // Minimax M2
@@ -309,6 +310,9 @@ struct llama_layer {
struct ggml_tensor * ssm_conv1d_b = nullptr;
struct ggml_tensor * ssm_dt_b = nullptr;
// qwen3next
struct ggml_tensor * ssm_beta_alpha = nullptr;
// rwkv
struct ggml_tensor * time_mix_w1 = nullptr;
struct ggml_tensor * time_mix_w2 = nullptr;
+13 -5
View File
@@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
continue;
} else if (remapped_name != it.first) {
}
if (remapped_name != it.first) {
ggml_set_name(it.second.tensor, remapped_name.c_str());
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
}
@@ -726,13 +728,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
{
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
if (llama_model_has_encoder(&model)) {
// now n_attn_layer is the number of attention layers in the encoder
// now n_layer_attn is the number of attention layers in the encoder
// for each decoder block, there are 2 attention layers
n_attn_layer += 2 * model.hparams.dec_n_layer;
n_layer_attn += 2 * model.hparams.dec_n_layer;
}
GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
}
size_t total_size_org = 0;
+5 -3
View File
@@ -9,6 +9,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
ggml_tensor * cur = build_inp_embd(model.tok_embd);
cb(cur, "model.embed_tokens", -1);
ggml_build_forward_expand(gf, cur);
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_hybrid = build_inp_mem_hybrid();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -40,12 +42,12 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
cur = ggml_add(ctx0, cur, ffn_out);
}
cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "model.embedding_norm", -1);
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cb(cur, "lm_head", -1);
cb(cur, "result_output", -1);
res->t_logits = cur;
+51 -1
View File
@@ -2,8 +2,9 @@
#include "../llama-model.h"
#include "../llama-graph.h"
#include "../llama-memory-recurrent.h"
// TODO: remove in follow-up PR - move to .cpp files
#include "../llama-memory-recurrent.h"
#include <cmath>
struct llm_graph_context_mamba : public llm_graph_context {
@@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context {
struct llm_build_qwen3vlmoe : public llm_graph_context {
llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_qwen3next : public llm_graph_context_mamba {
llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
private:
ggml_tensor * build_layer_attn(
llm_graph_input_attn_kv * inp_attn,
ggml_tensor * cur,
ggml_tensor * inp_pos,
int il);
ggml_tensor * build_layer_attn_linear(
llm_graph_input_rs * inp,
ggml_tensor * cur,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_layer_ffn(
ggml_tensor * cur,
int il);
ggml_tensor * build_delta_net_recurrent(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_delta_net_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_norm_gated(
ggml_tensor * input,
ggml_tensor * weights,
ggml_tensor * gate,
int layer);
const llama_model & model;
};
struct llm_build_qwen : public llm_graph_context {
llm_build_qwen(const llama_model & model, const llm_graph_params & params);
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -196,7 +196,7 @@ if (NOT WIN32)
llama_build_and_test(test-arg-parser.cpp)
endif()
if (NOT LLAMA_SANITIZE_ADDRESS)
if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC)
# TODO: repair known memory leaks
llama_build_and_test(test-opt.cpp)
endif()
+8 -3
View File
@@ -1446,14 +1446,14 @@ struct test_case {
const uint64_t target_flops_cpu = 8ULL * GFLOP;
const uint64_t target_flops_gpu = 100ULL * GFLOP;
uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
} else {
// based on memory size
const size_t GB = 1ULL << 30;
const size_t target_size_cpu = 8 * GB;
const size_t target_size_gpu = 32 * GB;
size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
}
// duplicate the op
@@ -7935,6 +7935,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
@@ -8040,7 +8043,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
for (auto k : {1, 10, 40}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
for (auto k : {1, 10, 40, 400}) {
for (auto nrows : {1, 16}) {
for (auto cols : {k, 1000, 65000, 200000}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
+26
View File
@@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
test({
SUCCESS,
"literal string with escapes",
R"""({
"properties": {
"code": {
"const": " \r \n \" \\ ",
"description": "Generated code",
"title": "Code",
"type": "string"
}
},
"required": [
"code"
],
"title": "DecoderResponse",
"type": "object"
})""",
R"""(
code ::= "\" \\r \\n \\\" \\\\ \"" space
code-kv ::= "\"code\"" space ":" space code
root ::= "{" space code-kv "}" space
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
}
int main() {
+2
View File
@@ -21,6 +21,8 @@ set(TARGET_SRCS
server-queue.h
server-common.cpp
server-common.h
server-context.cpp
server-context.h
)
set(PUBLIC_ASSETS
index.html.gz
+72
View File
@@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
**Features:**
* LLM inference of F16 and quantized models on GPU and CPU
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
* [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions
* Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
* Parallel decoding with multi-user support
* Continuous batching
@@ -1352,6 +1353,77 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
}'
```
### POST `/v1/messages`: Anthropic-compatible Messages API
Given a list of `messages`, returns the assistant's response. Streaming is supported via Server-Sent Events. While no strong claims of compatibility with the Anthropic API spec are made, in our experience it suffices to support many apps.
*Options:*
See [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires `--jinja` flag.
`model`: Model identifier (required)
`messages`: Array of message objects with `role` and `content` (required)
`max_tokens`: Maximum tokens to generate (default: 4096)
`system`: System prompt as string or array of content blocks
`temperature`: Sampling temperature 0-1 (default: 1.0)
`top_p`: Nucleus sampling (default: 1.0)
`top_k`: Top-k sampling
`stop_sequences`: Array of stop sequences
`stream`: Enable streaming (default: false)
`tools`: Array of tool definitions (requires `--jinja`)
`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`)
*Examples:*
```shell
curl http://localhost:8080/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: your-api-key" \
-d '{
"model": "gpt-4",
"max_tokens": 1024,
"system": "You are a helpful assistant.",
"messages": [
{"role": "user", "content": "Hello!"}
]
}'
```
### POST `/v1/messages/count_tokens`: Token Counting
Counts the number of tokens in a request without generating a response.
Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required.
*Example:*
```shell
curl http://localhost:8080/v1/messages/count_tokens \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello!"}
]
}'
```
*Response:*
```json
{"input_tokens": 10}
```
## More examples
### Interactive mode
@@ -257,9 +257,9 @@ const STRING_FORMAT_RULES = {
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES};
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"\\]/g;
const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' };
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\' };
const NON_LITERAL_SET = new Set('|.()[]{}*+?');
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?');
+240 -2
View File
@@ -725,7 +725,6 @@ std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtm
return result;
}
//
// OAI utils
//
@@ -1048,6 +1047,222 @@ json oaicompat_chat_params_parse(
return llama_params;
}
json convert_anthropic_to_oai(const json & body) {
json oai_body;
// Convert system prompt
json oai_messages = json::array();
auto system_param = json_value(body, "system", json());
if (!system_param.is_null()) {
std::string system_content;
if (system_param.is_string()) {
system_content = system_param.get<std::string>();
} else if (system_param.is_array()) {
for (const auto & block : system_param) {
if (json_value(block, "type", std::string()) == "text") {
system_content += json_value(block, "text", std::string());
}
}
}
oai_messages.push_back({
{"role", "system"},
{"content", system_content}
});
}
// Convert messages
if (!body.contains("messages")) {
throw std::runtime_error("'messages' is required");
}
const json & messages = body.at("messages");
if (messages.is_array()) {
for (const auto & msg : messages) {
std::string role = json_value(msg, "role", std::string());
if (!msg.contains("content")) {
if (role == "assistant") {
continue;
}
oai_messages.push_back(msg);
continue;
}
const json & content = msg.at("content");
if (content.is_string()) {
oai_messages.push_back(msg);
continue;
}
if (!content.is_array()) {
oai_messages.push_back(msg);
continue;
}
json tool_calls = json::array();
json converted_content = json::array();
json tool_results = json::array();
bool has_tool_calls = false;
for (const auto & block : content) {
std::string type = json_value(block, "type", std::string());
if (type == "text") {
converted_content.push_back(block);
} else if (type == "image") {
json source = json_value(block, "source", json::object());
std::string source_type = json_value(source, "type", std::string());
if (source_type == "base64") {
std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
std::string data = json_value(source, "data", std::string());
std::ostringstream ss;
ss << "data:" << media_type << ";base64," << data;
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", ss.str()}
}}
});
} else if (source_type == "url") {
std::string url = json_value(source, "url", std::string());
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", url}
}}
});
}
} else if (type == "tool_use") {
tool_calls.push_back({
{"id", json_value(block, "id", std::string())},
{"type", "function"},
{"function", {
{"name", json_value(block, "name", std::string())},
{"arguments", json_value(block, "input", json::object()).dump()}
}}
});
has_tool_calls = true;
} else if (type == "tool_result") {
std::string tool_use_id = json_value(block, "tool_use_id", std::string());
auto result_content = json_value(block, "content", json());
std::string result_text;
if (result_content.is_string()) {
result_text = result_content.get<std::string>();
} else if (result_content.is_array()) {
for (const auto & c : result_content) {
if (json_value(c, "type", std::string()) == "text") {
result_text += json_value(c, "text", std::string());
}
}
}
tool_results.push_back({
{"role", "tool"},
{"tool_call_id", tool_use_id},
{"content", result_text}
});
}
}
if (!converted_content.empty() || has_tool_calls) {
json new_msg = {{"role", role}};
if (!converted_content.empty()) {
new_msg["content"] = converted_content;
} else if (has_tool_calls) {
new_msg["content"] = "";
}
if (!tool_calls.empty()) {
new_msg["tool_calls"] = tool_calls;
}
oai_messages.push_back(new_msg);
}
for (const auto & tool_msg : tool_results) {
oai_messages.push_back(tool_msg);
}
}
}
oai_body["messages"] = oai_messages;
// Convert tools
if (body.contains("tools")) {
const json & tools = body.at("tools");
if (tools.is_array()) {
json oai_tools = json::array();
for (const auto & tool : tools) {
oai_tools.push_back({
{"type", "function"},
{"function", {
{"name", json_value(tool, "name", std::string())},
{"description", json_value(tool, "description", std::string())},
{"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
}}
});
}
oai_body["tools"] = oai_tools;
}
}
// Convert tool_choice
if (body.contains("tool_choice")) {
const json & tc = body.at("tool_choice");
if (tc.is_object()) {
std::string type = json_value(tc, "type", std::string());
if (type == "auto") {
oai_body["tool_choice"] = "auto";
} else if (type == "any" || type == "tool") {
oai_body["tool_choice"] = "required";
}
}
}
// Convert stop_sequences to stop
if (body.contains("stop_sequences")) {
oai_body["stop"] = body.at("stop_sequences");
}
// Handle max_tokens (required in Anthropic, but we're permissive)
if (body.contains("max_tokens")) {
oai_body["max_tokens"] = body.at("max_tokens");
} else {
oai_body["max_tokens"] = 4096;
}
// Pass through common params
for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
if (body.contains(key)) {
oai_body[key] = body.at(key);
}
}
// Handle Anthropic-specific thinking param
if (body.contains("thinking")) {
json thinking = json_value(body, "thinking", json::object());
std::string thinking_type = json_value(thinking, "type", std::string());
if (thinking_type == "enabled") {
int budget_tokens = json_value(thinking, "budget_tokens", 10000);
oai_body["thinking_budget_tokens"] = budget_tokens;
}
}
// Handle Anthropic-specific metadata param
if (body.contains("metadata")) {
json metadata = json_value(body, "metadata", json::object());
std::string user_id = json_value(metadata, "user_id", std::string());
if (!user_id.empty()) {
oai_body["__metadata_user_id"] = user_id;
}
}
return oai_body;
}
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
json data = json::array();
int32_t n_tokens = 0;
@@ -1211,7 +1426,7 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_sse(const json & data) {
std::string format_oai_sse(const json & data) {
std::ostringstream ss;
auto send_single = [&ss](const json & data) {
ss << "data: " <<
@@ -1230,6 +1445,29 @@ std::string format_sse(const json & data) {
return ss.str();
}
std::string format_anthropic_sse(const json & data) {
std::ostringstream ss;
auto send_event = [&ss](const json & event_obj) {
if (event_obj.contains("event") && event_obj.contains("data")) {
ss << "event: " << event_obj.at("event").get<std::string>() << "\n";
ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n";
} else {
ss << "data: " << safe_json_to_str(event_obj) << "\n\n";
}
};
if (data.is_array()) {
for (const auto & event : data) {
send_event(event);
}
} else {
send_event(data);
}
return ss.str();
}
bool is_valid_utf8(const std::string & str) {
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
const unsigned char* end = bytes + str.length();
+7 -1
View File
@@ -294,6 +294,9 @@ json oaicompat_chat_params_parse(
const oaicompat_parser_options & opt,
std::vector<raw_buffer> & out_files);
// convert Anthropic Messages API format to OpenAI Chat Completions API format
json convert_anthropic_to_oai(const json & body);
// TODO: move it to server-task.cpp
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
@@ -320,7 +323,10 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_sse(const json & data);
std::string format_oai_sse(const json & data);
// format Anthropic-style SSE with event types
std::string format_anthropic_sse(const json & data);
bool is_valid_utf8(const std::string & str);
File diff suppressed because it is too large Load Diff
+83
View File
@@ -0,0 +1,83 @@
#include "server-http.h"
#include "server-task.h"
#include "server-queue.h"
#include <nlohmann/json_fwd.hpp>
#include <cstddef>
#include <memory>
struct server_context_impl; // private implementation
struct server_context {
std::unique_ptr<server_context_impl> impl;
server_context();
~server_context();
// initialize slots and server-related data
void init();
// load the model and initialize llama_context
// returns true on success
bool load_model(const common_params & params);
// this function will block main thread until termination
void start_loop();
// terminate main loop (will unblock start_loop)
void terminate();
// get the underlaying llama_context
llama_context * get_llama_context() const;
// get the underlaying queue_tasks and queue_results
// used by CLI application
std::pair<server_queue &, server_response &> get_queues();
};
// forward declarations
struct server_res_generator;
struct server_routes {
server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
: params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
init_routes();
}
void init_routes();
// handlers using lambda function, so that they can capture `this` without `std::bind`
server_http_context::handler_t get_health;
server_http_context::handler_t get_metrics;
server_http_context::handler_t get_slots;
server_http_context::handler_t post_slots;
server_http_context::handler_t get_props;
server_http_context::handler_t post_props;
server_http_context::handler_t get_api_show;
server_http_context::handler_t post_infill;
server_http_context::handler_t post_completions;
server_http_context::handler_t post_completions_oai;
server_http_context::handler_t post_chat_completions;
server_http_context::handler_t post_anthropic_messages;
server_http_context::handler_t post_anthropic_count_tokens;
server_http_context::handler_t post_apply_template;
server_http_context::handler_t get_models;
server_http_context::handler_t post_tokenize;
server_http_context::handler_t post_detokenize;
server_http_context::handler_t post_embeddings;
server_http_context::handler_t post_embeddings_oai;
server_http_context::handler_t post_rerank;
server_http_context::handler_t get_lora_adapters;
server_http_context::handler_t post_lora_adapters;
private:
// TODO: move these outside of server_routes?
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
const common_params & params;
server_context_impl & ctx_server;
std::function<bool()> is_ready;
};
+14 -7
View File
@@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) {
return true;
}
// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
// Check for API key in the Authorization header
std::string req_api_key = req.get_header_value("Authorization");
if (req_api_key.empty()) {
// retry with anthropic header
req_api_key = req.get_header_value("X-Api-Key");
}
// remove the "Bearer " prefix if needed
std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
return true; // API key is valid
}
if (req_api_key.substr(0, prefix.size()) == prefix) {
req_api_key = req_api_key.substr(prefix.size());
}
// validate the API key
if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
return true; // API key is valid
}
// API key is invalid or not provided
+84 -1
View File
@@ -199,7 +199,7 @@ server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
if (!running) {
RES_DBG("%s : queue result stop\n", __func__);
RES_DBG("%s : queue result stop\n", "recv");
std::terminate(); // we cannot return here since the caller is HTTP code
}
return !queue_results.empty();
@@ -266,3 +266,86 @@ void server_response::terminate() {
running = false;
condition_results.notify_all();
}
//
// server_response_reader
//
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
id_tasks = server_task::get_list_id(tasks);
queue_results.add_waiting_tasks(tasks);
queue_tasks.post(std::move(tasks));
}
bool server_response_reader::has_next() const {
return !cancelled && received_count < id_tasks.size();
}
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
while (true) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
if (result == nullptr) {
// timeout, check stop condition
if (should_stop()) {
SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
return nullptr;
}
} else {
if (result->is_error()) {
stop(); // cancel remaining tasks
SRV_DBG("%s", "received error result, stopping further processing\n");
return result;
}
if (result->is_stop()) {
received_count++;
}
return result;
}
}
// should not reach here
}
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
batch_response batch_res;
batch_res.results.resize(id_tasks.size());
while (has_next()) {
auto res = next(should_stop);
if (res == nullptr) {
batch_res.is_terminated = true;
return batch_res;
}
if (res->is_error()) {
batch_res.error = std::move(res);
return batch_res;
}
const size_t idx = res->get_index();
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
batch_res.results[idx] = std::move(res);
}
return batch_res;
}
void server_response_reader::stop() {
queue_results.remove_waiting_task_ids(id_tasks);
if (has_next() && !cancelled) {
// if tasks is not finished yet, cancel them
cancelled = true;
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
queue_results.remove_waiting_task_id(id_task);
cancel_tasks.push_back(std::move(task));
}
// push to beginning of the queue, so it has highest priority
queue_tasks.post(std::move(cancel_tasks), true);
} else {
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
}
}
+36
View File
@@ -108,3 +108,39 @@ public:
// terminate the waiting loop
void terminate();
};
// utility class to make working with server_queue and server_response easier
// it provides a generator-like API for server responses
// support pooling connection state and aggregating multiple results
struct server_response_reader {
std::unordered_set<int> id_tasks;
server_queue & queue_tasks;
server_response & queue_results;
size_t received_count = 0;
bool cancelled = false;
int polling_interval_seconds;
// should_stop function will be called each polling_interval_seconds
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
~server_response_reader() {
stop();
}
void post_tasks(std::vector<server_task> && tasks);
bool has_next() const;
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()> & should_stop);
struct batch_response {
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
std::vector<server_task_result_ptr> results;
server_task_result_ptr error; // nullptr if no error
};
// aggregate multiple results
batch_response wait_for_all(const std::function<bool()> & should_stop);
void stop();
};
+293 -11
View File
@@ -565,15 +565,17 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
// server_task_result_cmpl_final
//
json server_task_result_cmpl_final::to_json() {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
case OAICOMPAT_TYPE_COMPLETION:
case TASK_RESPONSE_TYPE_OAI_CMPL:
return to_json_oaicompat();
case OAICOMPAT_TYPE_CHAT:
case TASK_RESPONSE_TYPE_OAI_CHAT:
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
case TASK_RESPONSE_TYPE_ANTHROPIC:
return stream ? to_json_anthropic_stream() : to_json_anthropic();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
GGML_ASSERT(false && "Invalid task_response_type");
}
}
@@ -768,19 +770,203 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
return deltas;
}
json server_task_result_cmpl_final::to_json_anthropic() {
std::string stop_reason = "max_tokens";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
}
json content_blocks = json::array();
common_chat_msg msg;
if (!oaicompat_msg.empty()) {
msg = oaicompat_msg;
} else {
msg.role = "assistant";
msg.content = content;
}
if (!msg.content.empty()) {
content_blocks.push_back({
{"type", "text"},
{"text", msg.content}
});
}
for (const auto & tool_call : msg.tool_calls) {
json tool_use_block = {
{"type", "tool_use"},
{"id", tool_call.id},
{"name", tool_call.name}
};
try {
tool_use_block["input"] = json::parse(tool_call.arguments);
} catch (const std::exception &) {
tool_use_block["input"] = json::object();
}
content_blocks.push_back(tool_use_block);
}
json res = {
{"id", oaicompat_cmpl_id},
{"type", "message"},
{"role", "assistant"},
{"content", content_blocks},
{"model", oaicompat_model},
{"stop_reason", stop_reason},
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
{"usage", {
{"input_tokens", n_prompt_tokens},
{"output_tokens", n_decoded}
}}
};
return res;
}
json server_task_result_cmpl_final::to_json_anthropic_stream() {
json events = json::array();
std::string stop_reason = "max_tokens";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
}
bool has_text = !oaicompat_msg.content.empty();
size_t num_tool_calls = oaicompat_msg.tool_calls.size();
bool text_block_started = false;
std::unordered_set<size_t> tool_calls_started;
for (const auto & diff : oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (!text_block_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", 0},
{"content_block", {
{"type", "text"},
{"text", ""}
}}
}}
});
text_block_started = true;
}
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", 0},
{"delta", {
{"type", "text_delta"},
{"text", diff.content_delta}
}}
}}
});
}
if (diff.tool_call_index != std::string::npos) {
size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", content_block_index},
{"content_block", {
{"type", "tool_use"},
{"id", full_tool_call.id},
{"name", full_tool_call.name}
}}
}}
});
tool_calls_started.insert(diff.tool_call_index);
}
if (!diff.tool_call_delta.arguments.empty()) {
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", content_block_index},
{"delta", {
{"type", "input_json_delta"},
{"partial_json", diff.tool_call_delta.arguments}
}}
}}
});
}
}
}
if (has_text) {
events.push_back({
{"event", "content_block_stop"},
{"data", {
{"type", "content_block_stop"},
{"index", 0}
}}
});
}
for (size_t i = 0; i < num_tool_calls; i++) {
size_t content_block_index = (has_text ? 1 : 0) + i;
events.push_back({
{"event", "content_block_stop"},
{"data", {
{"type", "content_block_stop"},
{"index", content_block_index}
}}
});
}
events.push_back({
{"event", "message_delta"},
{"data", {
{"type", "message_delta"},
{"delta", {
{"stop_reason", stop_reason},
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
}},
{"usage", {
{"output_tokens", n_decoded}
}}
}}
});
events.push_back({
{"event", "message_stop"},
{"data", {
{"type", "message_stop"}
}}
});
return events;
}
//
// server_task_result_cmpl_partial
//
json server_task_result_cmpl_partial::to_json() {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
case OAICOMPAT_TYPE_COMPLETION:
case TASK_RESPONSE_TYPE_OAI_CMPL:
return to_json_oaicompat();
case OAICOMPAT_TYPE_CHAT:
case TASK_RESPONSE_TYPE_OAI_CHAT:
return to_json_oaicompat_chat();
case TASK_RESPONSE_TYPE_ANTHROPIC:
return to_json_anthropic();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
GGML_ASSERT(false && "Invalid task_response_type");
}
}
@@ -905,7 +1091,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
// server_task_result_embd
//
json server_task_result_embd::to_json() {
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
return res_type == TASK_RESPONSE_TYPE_OAI_EMBD
? to_json_oaicompat()
: to_json_non_oaicompat();
}
@@ -936,6 +1122,102 @@ json server_task_result_rerank::to_json() {
};
}
json server_task_result_cmpl_partial::to_json_anthropic() {
json events = json::array();
bool first = (n_decoded == 1);
static bool text_block_started = false;
if (first) {
text_block_started = false;
events.push_back({
{"event", "message_start"},
{"data", {
{"type", "message_start"},
{"message", {
{"id", oaicompat_cmpl_id},
{"type", "message"},
{"role", "assistant"},
{"content", json::array()},
{"model", oaicompat_model},
{"stop_reason", nullptr},
{"stop_sequence", nullptr},
{"usage", {
{"input_tokens", n_prompt_tokens},
{"output_tokens", 0}
}}
}}
}}
});
}
for (const auto & diff : oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (!text_block_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", 0},
{"content_block", {
{"type", "text"},
{"text", ""}
}}
}}
});
text_block_started = true;
}
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", 0},
{"delta", {
{"type", "text_delta"},
{"text", diff.content_delta}
}}
}}
});
}
if (diff.tool_call_index != std::string::npos) {
size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
if (!diff.tool_call_delta.name.empty()) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", content_block_index},
{"content_block", {
{"type", "tool_use"},
{"id", diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name}
}}
}}
});
}
if (!diff.tool_call_delta.arguments.empty()) {
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", content_block_index},
{"delta", {
{"type", "input_json_delta"},
{"partial_json", diff.tool_call_delta.arguments}
}}
}}
});
}
}
}
return events;
}
//
// server_task_result_error
//
+27 -20
View File
@@ -27,11 +27,12 @@ enum server_task_type {
};
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum oaicompat_type {
OAICOMPAT_TYPE_NONE,
OAICOMPAT_TYPE_CHAT,
OAICOMPAT_TYPE_COMPLETION,
OAICOMPAT_TYPE_EMBEDDING,
enum task_response_type {
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
TASK_RESPONSE_TYPE_OAI_CHAT,
TASK_RESPONSE_TYPE_OAI_CMPL,
TASK_RESPONSE_TYPE_OAI_EMBD,
TASK_RESPONSE_TYPE_ANTHROPIC,
};
enum stop_type {
@@ -66,9 +67,9 @@ struct task_params {
struct common_params_sampling sampling;
struct common_params_speculative speculative;
// OAI-compat fields
// response formatting
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
@@ -227,12 +228,12 @@ struct server_task_result_cmpl_final : server_task_result {
task_params generation_params;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
@@ -253,6 +254,10 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat();
json to_json_oaicompat_chat_stream();
json to_json_anthropic();
json to_json_anthropic_stream();
};
struct server_task_result_cmpl_partial : server_task_result {
@@ -270,11 +275,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
result_prompt_progress progress;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
@@ -292,6 +297,8 @@ struct server_task_result_cmpl_partial : server_task_result {
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_anthropic();
};
struct server_task_result_embd : server_task_result {
@@ -300,8 +307,8 @@ struct server_task_result_embd : server_task_result {
int32_t n_tokens;
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
// response formatting
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
virtual int get_index() override {
return index;
+10 -3612
View File
File diff suppressed because it is too large Load Diff
+6
View File
@@ -13,3 +13,9 @@ def stop_server_after_each_test():
) # copy the set to prevent 'Set changed size during iteration'
for server in instances:
server.stop()
@pytest.fixture(scope="module", autouse=True)
def do_something():
# this will be run once per test session, before any tests
ServerPreset.load_all()
-6
View File
@@ -5,12 +5,6 @@ from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="session", autouse=True)
def do_something():
# this will be run once per test session, before any tests
ServerPreset.load_all()
@pytest.fixture(autouse=True)
def create_server():
global server
@@ -0,0 +1,807 @@
#!/usr/bin/env python3
import pytest
import base64
import requests
from utils import *
server: ServerProcess
def get_test_image_base64() -> str:
"""Get a test image in base64 format"""
# Use the same test image as test_vision_api.py
IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
response = requests.get(IMG_URL)
response.raise_for_status()
return base64.b64encode(response.content).decode("utf-8")
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-anthropic"
server.server_port = 8082
server.n_slots = 1
server.n_ctx = 8192
server.n_batch = 2048
@pytest.fixture
def vision_server():
"""Separate fixture for vision tests that require multimodal support"""
global server
server = ServerPreset.tinygemma3()
server.offline = False # Allow downloading the model
server.model_alias = "tinygemma3-anthropic"
server.server_port = 8083 # Different port to avoid conflicts
server.n_slots = 1
return server
# Basic message tests
def test_anthropic_messages_basic():
"""Test basic Anthropic messages endpoint"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Say hello"}
]
})
assert res.status_code == 200, f"Expected 200, got {res.status_code}"
assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}"
assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}"
assert "content" in res.body, "Missing 'content' field"
assert isinstance(res.body["content"], list), "Content should be an array"
assert len(res.body["content"]) > 0, "Content array should not be empty"
assert res.body["content"][0]["type"] == "text", "First content block should be text"
assert "text" in res.body["content"][0], "Text content block missing 'text' field"
assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}"
assert "usage" in res.body, "Missing 'usage' field"
assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens"
assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens"
assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer"
assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer"
assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens"
# Anthropic API should NOT include timings
assert "timings" not in res.body, "Anthropic API should not include timings field"
def test_anthropic_messages_with_system():
"""Test messages with system prompt"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"system": "You are a helpful assistant.",
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
def test_anthropic_messages_multipart_content():
"""Test messages with multipart content blocks"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What is"},
{"type": "text", "text": " the answer?"}
]
}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_messages_conversation():
"""Test multi-turn conversation"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
# Streaming tests
def test_anthropic_messages_streaming():
"""Test streaming messages"""
server.start()
res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 30,
"messages": [
{"role": "user", "content": "Say hello"}
],
"stream": True
})
events = []
for data in res:
# Each event should have type and other fields
assert "type" in data, f"Missing 'type' in event: {data}"
events.append(data)
# Verify event sequence
event_types = [e["type"] for e in events]
assert "message_start" in event_types, "Missing message_start event"
assert "content_block_start" in event_types, "Missing content_block_start event"
assert "content_block_delta" in event_types, "Missing content_block_delta event"
assert "content_block_stop" in event_types, "Missing content_block_stop event"
assert "message_delta" in event_types, "Missing message_delta event"
assert "message_stop" in event_types, "Missing message_stop event"
# Check message_start structure
message_start = next(e for e in events if e["type"] == "message_start")
assert "message" in message_start, "message_start missing 'message' field"
assert message_start["message"]["type"] == "message"
assert message_start["message"]["role"] == "assistant"
assert message_start["message"]["content"] == []
assert "usage" in message_start["message"]
assert message_start["message"]["usage"]["input_tokens"] > 0
# Check content_block_start
block_start = next(e for e in events if e["type"] == "content_block_start")
assert "index" in block_start, "content_block_start missing 'index'"
assert block_start["index"] == 0, "First content block should be at index 0"
assert "content_block" in block_start
assert block_start["content_block"]["type"] == "text"
# Check content_block_delta
deltas = [e for e in events if e["type"] == "content_block_delta"]
assert len(deltas) > 0, "Should have at least one content_block_delta"
for delta in deltas:
assert "index" in delta
assert "delta" in delta
assert delta["delta"]["type"] == "text_delta"
assert "text" in delta["delta"]
# Check content_block_stop
block_stop = next(e for e in events if e["type"] == "content_block_stop")
assert "index" in block_stop
assert block_stop["index"] == 0
# Check message_delta
message_delta = next(e for e in events if e["type"] == "message_delta")
assert "delta" in message_delta
assert "stop_reason" in message_delta["delta"]
assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"]
assert "usage" in message_delta
assert message_delta["usage"]["output_tokens"] > 0
# Check message_stop
message_stop = next(e for e in events if e["type"] == "message_stop")
# message_stop should NOT have timings for Anthropic API
assert "timings" not in message_stop, "Anthropic streaming should not include timings"
# Token counting tests
def test_anthropic_count_tokens():
"""Test token counting endpoint"""
server.start()
res = server.make_request("POST", "/v1/messages/count_tokens", data={
"model": "test",
"messages": [
{"role": "user", "content": "Hello world"}
]
})
assert res.status_code == 200
assert "input_tokens" in res.body
assert isinstance(res.body["input_tokens"], int)
assert res.body["input_tokens"] > 0
# Should only have input_tokens, no other fields
assert "output_tokens" not in res.body
def test_anthropic_count_tokens_with_system():
"""Test token counting with system prompt"""
server.start()
res = server.make_request("POST", "/v1/messages/count_tokens", data={
"model": "test",
"system": "You are a helpful assistant.",
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["input_tokens"] > 0
def test_anthropic_count_tokens_no_max_tokens():
"""Test that count_tokens doesn't require max_tokens"""
server.start()
# max_tokens is NOT required for count_tokens
res = server.make_request("POST", "/v1/messages/count_tokens", data={
"model": "test",
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert "input_tokens" in res.body
# Tool use tests
def test_anthropic_tool_use_basic():
"""Test basic tool use"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 200,
"tools": [{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City name"
}
},
"required": ["location"]
}
}],
"messages": [
{"role": "user", "content": "What's the weather in Paris?"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
# Check if model used the tool (it might not always, depending on the model)
content_types = [block.get("type") for block in res.body["content"]]
if "tool_use" in content_types:
# Model used the tool
assert res.body["stop_reason"] == "tool_use"
# Find the tool_use block
tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use")
assert "id" in tool_block
assert "name" in tool_block
assert tool_block["name"] == "get_weather"
assert "input" in tool_block
assert isinstance(tool_block["input"], dict)
def test_anthropic_tool_result():
"""Test sending tool results back
This test verifies that tool_result blocks are properly converted to
role="tool" messages internally. Without proper conversion, this would
fail with a 500 error: "unsupported content[].type" because tool_result
blocks would remain in the user message content array.
"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"messages": [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "test123",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "test123",
"content": "The weather is sunny, 25°C"
}
]
}
]
})
# This would be 500 with the old bug where tool_result blocks weren't converted
assert res.status_code == 200
assert res.body["type"] == "message"
# Model should respond to the tool result
assert len(res.body["content"]) > 0
assert res.body["content"][0]["type"] == "text"
def test_anthropic_tool_result_with_text():
"""Test tool result mixed with text content
This tests the edge case where a user message contains both text and
tool_result blocks. The server must properly split these into separate
messages: a user message with text, followed by tool messages.
Without proper handling, this would fail with 500: "unsupported content[].type"
"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"messages": [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "tool_1",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "Here are the results:"},
{
"type": "tool_result",
"tool_use_id": "tool_1",
"content": "Sunny, 25°C"
}
]
}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
def test_anthropic_tool_result_error():
"""Test tool result with error flag"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"messages": [
{"role": "user", "content": "Get the weather"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "test123",
"name": "get_weather",
"input": {"location": "InvalidCity"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "test123",
"is_error": True,
"content": "City not found"
}
]
}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_tool_streaming():
"""Test streaming with tool use"""
server.jinja = True
server.start()
res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 200,
"stream": True,
"tools": [{
"name": "calculator",
"description": "Calculate math",
"input_schema": {
"type": "object",
"properties": {
"expression": {"type": "string"}
},
"required": ["expression"]
}
}],
"messages": [
{"role": "user", "content": "Calculate 2+2"}
]
})
events = []
for data in res:
events.append(data)
event_types = [e["type"] for e in events]
# Should have basic events
assert "message_start" in event_types
assert "message_stop" in event_types
# If tool was used, check for proper tool streaming
if any(e.get("type") == "content_block_start" and
e.get("content_block", {}).get("type") == "tool_use"
for e in events):
# Find tool use block start
tool_starts = [e for e in events if
e.get("type") == "content_block_start" and
e.get("content_block", {}).get("type") == "tool_use"]
assert len(tool_starts) > 0, "Should have tool_use content_block_start"
# Check index is correct (should be 0 if no text, 1 if there's text)
tool_start = tool_starts[0]
assert "index" in tool_start
assert tool_start["content_block"]["type"] == "tool_use"
assert "name" in tool_start["content_block"]
# Vision/multimodal tests
def test_anthropic_vision_format_accepted():
"""Test that Anthropic vision format is accepted (format validation only)"""
server.start()
# Small 1x1 red PNG image in base64
red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 10,
"messages": [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": red_pixel_png
}
},
{
"type": "text",
"text": "What is this?"
}
]
}
]
})
# Server accepts the format but tinyllama doesn't support images
# So it should return 500 with clear error message about missing mmproj
assert res.status_code == 500
assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower()
def test_anthropic_vision_base64_with_multimodal_model(vision_server):
"""Test vision with base64 image using Anthropic format with multimodal model"""
global server
server = vision_server
server.start()
# Get test image in base64 format
image_base64 = get_test_image_base64()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 10,
"messages": [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": image_base64
}
},
{
"type": "text",
"text": "What is this:\n"
}
]
}
]
})
assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}"
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
assert res.body["content"][0]["type"] == "text"
# The model should generate some response about the image
assert len(res.body["content"][0]["text"]) > 0
# Parameter tests
def test_anthropic_stop_sequences():
"""Test stop_sequences parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Count to 10"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_temperature():
"""Test temperature parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"temperature": 0.5,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_top_p():
"""Test top_p parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"top_p": 0.9,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_top_k():
"""Test top_k parameter (llama.cpp specific)"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"top_k": 40,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
# Error handling tests
def test_anthropic_missing_messages():
"""Test error when messages are missing"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50
# missing "messages" field
})
# Should return an error (400 or 500)
assert res.status_code >= 400
def test_anthropic_empty_messages():
"""Test permissive handling of empty messages array"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": []
})
# Server is permissive and accepts empty messages (provides defaults)
# This matches the permissive validation design choice
assert res.status_code == 200
assert res.body["type"] == "message"
# Content block index tests
def test_anthropic_streaming_content_block_indices():
"""Test that content block indices are correct in streaming"""
server.jinja = True
server.start()
# Request that might produce both text and tool use
res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 200,
"stream": True,
"tools": [{
"name": "test_tool",
"description": "A test tool",
"input_schema": {
"type": "object",
"properties": {
"param": {"type": "string"}
},
"required": ["param"]
}
}],
"messages": [
{"role": "user", "content": "Use the test tool"}
]
})
events = []
for data in res:
events.append(data)
# Check content_block_start events have sequential indices
block_starts = [e for e in events if e.get("type") == "content_block_start"]
if len(block_starts) > 1:
# If there are multiple blocks, indices should be sequential
indices = [e["index"] for e in block_starts]
expected_indices = list(range(len(block_starts)))
assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}"
# Check content_block_stop events match the starts
block_stops = [e for e in events if e.get("type") == "content_block_stop"]
start_indices = set(e["index"] for e in block_starts)
stop_indices = set(e["index"] for e in block_stops)
assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices"
# Extended features tests
def test_anthropic_thinking():
"""Test extended thinking parameter"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"thinking": {
"type": "enabled",
"budget_tokens": 50
},
"messages": [
{"role": "user", "content": "What is 2+2?"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_metadata():
"""Test metadata parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"metadata": {
"user_id": "test_user_123"
},
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
# Compatibility tests
def test_anthropic_vs_openai_different_response_format():
"""Verify Anthropic format is different from OpenAI format"""
server.start()
# Make OpenAI request
openai_res = server.make_request("POST", "/v1/chat/completions", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Hello"}
]
})
# Make Anthropic request
anthropic_res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert openai_res.status_code == 200
assert anthropic_res.status_code == 200
# OpenAI has "object", Anthropic has "type"
assert "object" in openai_res.body
assert "type" in anthropic_res.body
assert openai_res.body["object"] == "chat.completion"
assert anthropic_res.body["type"] == "message"
# OpenAI has "choices", Anthropic has "content"
assert "choices" in openai_res.body
assert "content" in anthropic_res.body
# Different usage field names
assert "prompt_tokens" in openai_res.body["usage"]
assert "input_tokens" in anthropic_res.body["usage"]
assert "completion_tokens" in openai_res.body["usage"]
assert "output_tokens" in anthropic_res.body["usage"]
+13
View File
@@ -49,6 +49,19 @@ def test_correct_api_key():
assert "content" in res.body
def test_correct_api_key_anthropic_header():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"X-Api-Key": TEST_API_KEY,
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_openai_library_correct_api_key():
global server
server.start()