Compare commits

...

38 Commits

Author SHA1 Message Date
uvos 5f91b1d5d5 ggml-cuda: gdn use shared mem for HIP (#20366)
Suggested-by: Aman Gupta <amangupta052@gmail.com>
2026-03-11 13:06:19 +08:00
uvos 9ef7523ee9 cuda/hip: fix loop unrolling in ssm-conv (#20369) 2026-03-11 13:04:32 +08:00
Pascal 00de615345 Fix agentic mcp image single model (#20339)
* webui: fix MCP image attachments dropped during the agentic loop in single-model mode

* chore: update webui build output
2026-03-11 05:31:33 +01:00
Alessandro de Oliveira Faria (A.K.A.CABELO) e1a399992b vendor : update cpp-httplib to 0.37.0 (#20207) 2026-03-11 11:03:53 +08:00
Alessandro de Oliveira Faria (A.K.A.CABELO) 4f2f0a163d vendor : update miniaudio to 0.11.25 (#20209) 2026-03-11 11:01:56 +08:00
Neo Zhang 0cec84f999 fix op rope, add rope_back (#20293) 2026-03-11 09:53:34 +08:00
Neo Zhang b2e1427c9b fix for failed UT case: ACC, L2_NORM, UPSCALE, fused_glu, unary (#20283) 2026-03-11 09:53:05 +08:00
Vinicios Lugli 4d99d45084 model : qwen3vl reranker text support (#20332)
* model : fix qwen3vl reranker support

* Remove CLS_OUT

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-10 23:40:14 +01:00
ddh0 10e5b148b0 llama-quant : correct n_attention_wv usage (#20357)
* llama-quant : correct `n_attention_wv` usage

In #19770, I introduced a regression in the way the
`quantize_state_impl` counter values were initialized. I was
incrementing and using `n_attention_wv` in the same loop, when it should
have been fixed by the time we're deciding tensor types in
`llama_tensor_get_type_impl` (for `use_more_bits`).

I never observed a difference in any of [my
tests](https://github.com/ggml-org/llama.cpp/pull/19770#issuecomment-4000424712)
- it was only after @bartowski kindly pointed this out that I realized
it was incorrect. (Thanks!)

* simplify
2026-03-10 21:43:29 +02:00
Georgi Gerganov 90b2731894 ggml : bump RPC version (#20330) 2026-03-10 21:36:57 +02:00
Reese Levine aa2d278a11 ggml webgpu: faster normal quant and some k-quant matrix operations, better shader parameter handling (#20173)
* K quant speedup (#20)

* Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* no gibberish, all k quants added, merged

* vec memory fix

* q6_k matching metal on my machine, tests passing

* Set tile size for q6_k separately

* Separate out fast shaders

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>

* Move towards writeBuffer for params

* Move away from multiple buffers for set_rows errors, remove host buffer for parameter buffers, minor cleanups

* Remove extra file

* Formatting

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
2026-03-10 09:14:27 -07:00
Piotr Wilkin (ilintar) 6c770d16ca Reduce level of content parser warning message to avoid log spam on non-debug verbosity (#20347) 2026-03-10 15:21:51 +01:00
Ray Xu 8d880ac012 examples : fix empty items in json_schema_to_grammar.py [no ci] (#19968)
* Fix logic for retrieving schema items in `json_schema_to_grammar.py`

If `schema['items']` is `{}` and `prefixItems not in schema', as `{}` is Falsy, the original code here will raise an error.

I think if `schema['items']` is `{}`, them items should just be `{}`

* Apply suggestion from @CISC

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Add tests for arrays with empty items

Add two unit tests to `tests/test-json-schema-to-grammar.cpp` that validate handling of arrays when 'items' is an empty schema and when 'prefixItems' is present alongside an empty 'items'. Both tests expect the same generated grammar, ensuring the JSON Schema->grammar conversion treats an empty 'items' schema (and the presence of 'prefixItems') correctly and covering this edge case.

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-10 14:38:18 +01:00
a3894281 0f1e9d14cc docs: update CPU backend ops to mark POOL_1D as supported (#20304) 2026-03-10 21:31:24 +08:00
Georgi Gerganov 1274fbee9e models : fix assert in mamba2 (cont) (#20335)
* models : fix assert in mamba2 (cont)

* cont : add n_group mod

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-10 15:00:08 +02:00
Georgi Gerganov a7b3dee7a5 server : make 2 checkpoints near the end of the prompt (#20288)
* server : make 2 checkpoints near the end of the prompt

* cont : adjust checkpoints
2026-03-10 14:28:23 +02:00
Sigbjørn Skjæret ec947d2b16 common : fix incorrect uses of stoul (#20313) 2026-03-10 11:40:26 +01:00
Charles Xu 0cd4f4720b kleidiai : support for concurrent sme and neon kernel execution (#20070) 2026-03-10 09:25:25 +02:00
Taimur Ahmad af237f3026 ggml-cpu: add RVV repack GEMM and GEMV for quantization types (#19121)
* ggml-cpu: add rvv ggml_quantize_mat_4x8 for q8_0

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: add rvv repacking for iq4_nl

* ggml-cpu: add generic impl for iq4_nl gemm/gemv

* ggml-cpu: add rvv repacking for q8_0

* ggml-cpu: refactor; add rvv repacking for q4_0, q4_K

* ggml-cpu: refactor; add rvv repacking for q2_K

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>

* ggml-cpu: refactor rvv repack

---------

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
2026-03-10 08:49:52 +02:00
Julian Pscheid 1a5631beaa metal: handle command buffer failures gracefully in synchronize (#20306)
Replace GGML_ABORT("fatal error") in ggml_metal_synchronize() with
error flag + return. This aligns synchronize error handling with
graph_compute, which already returns GGML_STATUS_FAILED for the same
condition.

When a command buffer fails (e.g., iOS GPU access revocation during
backgrounding, macOS eGPU disconnect, OOM), the backend enters an
error state instead of killing the host process. Subsequent
graph_compute calls return GGML_STATUS_FAILED immediately. Recovery
requires recreating the backend.

Failed extra command buffers are properly released on the error path
to avoid Metal object leaks.
2026-03-10 08:32:24 +02:00
ddh0 1dab5f5a44 llama-quant : fail early on missing imatrix, refactor type selection, code cleanup (#19770)
* quantize : imatrix-fail early + code cleanup

* fix manual override printing

it's in the preliminary loop now, so needs to be on its own line

* revert header changes per ggerganov

* remove old #includes

* clarify naming

rename `tensor_quantization` to `tensor_typo_option` to descirbe its
functionality

* fix per barto
2026-03-10 08:16:05 +02:00
Aldehir Rojas c96f608d98 common: consolidate PEG string parsers (#20263)
* common : consolidate PEG string parsers
* cont : fix json_string_content()
2026-03-10 00:29:21 +01:00
Xuan-Son Nguyen 0842b9b465 model: fix step3.5 n_rot (#20318) 2026-03-09 23:42:24 +01:00
Xuan-Son Nguyen 59db9a357d llama: dynamic head_dim and n_rot for SWA (#20301)
* llama: dynamic head_dim and n_rot for SWA

* also add gguf_writer wrappers

* fix build

* build_rope_shift arg reorder
2026-03-09 22:22:39 +01:00
Evan Huus 23fbfcb1ad server: Parse port numbers from MCP server URLs in CORS proxy (#20208)
* Parse port numbers from MCP server URLs

* Pass scheme to http proxy for determining whether to use SSL

* Fix download on non-standard port and re-add port to logging

* add test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-03-09 17:47:54 +01:00
Paul Flynn e22cd0aa15 metal : extend mul_mv_ext to BF16, Q2_K, Q3_K (#20250)
Enable mul_mv_ext small-batch kernels (BS 2-8) for BF16, Q2_K,
and Q3_K quantization types. These types previously fell through
to the slower single-row mul_mv path.

BF16 uses the float4 dequantize path (like F16). Q2_K and Q3_K
use the float4x4 K-quant path (like Q4_K/Q5_K/Q6_K).

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 16:48:12 +02:00
Georgi Gerganov 96cfc4992c server : fix checkpoints n_tokens calculation (#20287) 2026-03-09 16:47:06 +02:00
Georgi Gerganov ed0007aa32 metal : add upscale (#20284) 2026-03-09 16:45:11 +02:00
Georgi Gerganov 344ee2a38a server : warn swa-full is not supported for non-SWA models (#20291) 2026-03-09 16:44:25 +02:00
Georgi Gerganov d6e1556499 server : fix off-by-1 in server_tokens::size_up_to_pos() (#20279)
* server : fix off-by-1 in server_tokens::size_up_to_pos()

* cont : fix typo [no ci]
2026-03-09 16:43:38 +02:00
Piotr Wilkin (ilintar) f76565db92 common: map developer role to system (#20215)
* Map developer role to system
* Simplify
2026-03-09 14:25:11 +01:00
Georgi Gerganov 43e1cbd6c1 models : fix assert in mamba2 graph (#20270) 2026-03-09 13:15:15 +02:00
Georgi Gerganov 107d599952 server : add kill switch when server is stuck (#20277) 2026-03-09 10:33:12 +02:00
Aman Gupta e8bbc736cb ggml-cuda: disable gdn for musa (#20278) 2026-03-09 16:15:36 +08:00
ddh0 b518195101 llama-quant : left-align tensor names in output (#20117) 2026-03-09 09:28:41 +02:00
Aman Gupta e2763a6723 contributing: limit open PRs for new contributors to 1 (#20036) 2026-03-09 15:05:34 +08:00
Bertay Eren 0beb8db3a0 ggml-vulkan: add SGN operator, auto-generate Vulkan.csv and ops.md (#20219) 2026-03-09 07:24:16 +01:00
Ruben Ortlam b2f460bd3c vulkan: skip zero size tensors in backend copies (#20233) 2026-03-09 07:23:45 +01:00
182 changed files with 10285 additions and 22062 deletions
+1
View File
@@ -39,6 +39,7 @@ Before submitting your PR:
- For intricate features, consider opening a feature request first to discuss and align expectations
- When adding support for a new model or feature, focus on **CPU support only** in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
- If you are a new contributor, limit your open PRs to 1.
After submitting your PR:
- Expect requests for modifications to ensure the code meets llama.cpp's standards for quality and long-term maintainability
+2 -2
View File
@@ -2427,11 +2427,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
);
}
if (split_arg.size() == 1) {
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoull(split_arg[0]) * 1024*1024);
return;
}
for (size_t i = 0; i < split_arg.size(); i++) {
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
params.fit_params_target[i] = std::stoull(split_arg[i]) * 1024*1024;
}
}
).set_env("LLAMA_ARG_FIT_TARGET"));
+1 -1
View File
@@ -90,7 +90,7 @@ common_peg_arena autoparser::build_parser(const templates_params & inputs) const
// pre-register a json-string rule that accepts both quote styles. This must happen
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
if (tools.format.uses_python_dicts) {
p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); });
p.rule("json-string", p.quoted_string());
}
parser_build_context ctx(p, inputs);
+8 -8
View File
@@ -507,8 +507,8 @@ common_peg_parser common_chat_peg_builder::python_style_tool_calls(
common_peg_parser arg_value_parser = eps();
auto string_value_parser = choice({
literal("\"") + tool_arg_string_value(json_string_content()) + literal("\""),
literal("'") + tool_arg_string_value(json_string_content()) + literal("'")
literal("\"") + tool_arg_string_value(string_content('"')) + literal("\""),
literal("'") + tool_arg_string_value(string_content('\'')) + literal("'")
});
if (is_string_type) {
@@ -577,7 +577,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
if (!call_id_key.empty()) {
auto id_parser = atomic(
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
literal("\"") + tool_id(json_string_content()) + literal("\"")
literal("\"") + tool_id(string_content('"')) + literal("\"")
);
inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space())));
}
@@ -586,7 +586,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
auto gen_id_parser = atomic(
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
@@ -675,7 +675,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
if (id_spec.first.empty()) {
auto id_parser = atomic(
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
literal("\"") + tool_id(json_string_content()) + literal("\"")
literal("\"") + tool_id(string_content('"')) + literal("\"")
);
tool_parser_body = tool_parser_body + optional(id_parser + space() + literal(",") + space());
}
@@ -687,7 +687,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
auto gen_id_parser = atomic(
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
@@ -736,7 +736,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
id_parser = atomic(
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
@@ -747,7 +747,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
gen_id_parser = atomic(
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
choice({
literal("\"") + tool_id(json_string_content()) + literal("\""),
literal("\"") + tool_id(string_content('"')) + literal("\""),
tool_id(json_number())
})
);
+17 -2
View File
@@ -1352,6 +1352,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
namespace workaround {
static void map_developer_role_to_system(json & messages) {
for (auto & message : messages) {
if (message.contains("role")) {
if (message["role"] == "developer") {
message["role"] = "system";
}
}
}
}
// if first message is system and template does not support it, merge it with next message
static void system_message_not_supported(json & messages) {
if (!messages.empty() && messages.front().at("role") == "system") {
@@ -1429,6 +1440,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
params.add_bos = tmpls->add_bos;
params.add_eos = tmpls->add_eos;
if (src.find("<|channel|>") == std::string::npos) {
// map developer to system for all models except for GPT-OSS
workaround::map_developer_role_to_system(params.messages);
}
workaround::func_args_not_string(params.messages);
if (!tmpl.original_caps().supports_system_role) {
@@ -1605,8 +1620,8 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) :
src_parser;
if (src_parser.empty()) {
LOG_WRN("No parser definition detected, assuming pure content parser.");
if (src_parser.empty()) {
LOG_DBG("No parser definition detected, assuming pure content parser.");
}
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
+16 -1
View File
@@ -7,6 +7,7 @@ struct common_http_url {
std::string user;
std::string password;
std::string host;
int port;
std::string path;
};
@@ -47,6 +48,20 @@ static common_http_url common_http_parse_url(const std::string & url) {
parts.host = rest;
parts.path = "/";
}
auto colon_pos = parts.host.find(':');
if (colon_pos != std::string::npos) {
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
parts.host = parts.host.substr(0, colon_pos);
} else if (parts.scheme == "http") {
parts.port = 80;
} else if (parts.scheme == "https") {
parts.port = 443;
} else {
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
}
return parts;
}
@@ -68,7 +83,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
#endif
httplib::Client cli(parts.scheme + "://" + parts.host);
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
if (!parts.user.empty()) {
cli.set_basic_auth(parts.user, parts.password);
+1 -1
View File
@@ -790,7 +790,7 @@ public:
} else if (target.is_array()) {
size_t sel_index;
try {
sel_index = std::stoul(sel);
sel_index = std::stoull(sel);
} catch (const std::invalid_argument & e) {
sel_index = target.size();
}
+119 -129
View File
@@ -658,7 +658,7 @@ struct parser_executor {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) {
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos, const char delimiter) {
++pos; // consume '\'
if (pos >= ctx.input.size()) {
if (!ctx.is_lenient()) {
@@ -667,23 +667,14 @@ struct parser_executor {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
}
switch (ctx.input[pos]) {
case '"':
case '\'':
case '\\':
case '/':
case 'b':
case 'f':
case 'n':
case 'r':
case 't':
++pos;
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
case 'u':
return handle_unicode_escape(ctx, start, pos);
default:
// Invalid escape sequence
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
char c = ctx.input[pos];
if (c == delimiter || c == '\\' || c == '/' || c == 'b' || c == 'f' || c == 'n' || c == 'r' || c == 't') {
++pos;
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
} else if (c == 'u') {
return handle_unicode_escape(ctx, start, pos);
} else {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
}
}
@@ -704,62 +695,20 @@ struct parser_executor {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
}
common_peg_parse_result operator()(const common_peg_json_string_parser & /* p */) {
common_peg_parse_result operator()(const common_peg_string_parser & p) {
auto pos = start_pos;
// Parse string content (without quotes)
while (pos < ctx.input.size()) {
char c = ctx.input[pos];
if (c == '"') {
// Found closing quote - success (don't consume it)
if (c == p.delimiter) {
// Found closing delimiter - success (don't consume it)
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
if (c == '\\') {
auto result = handle_escape_sequence(ctx, start_pos, pos);
if (!result.success()) {
return result;
}
} else {
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
}
if (utf8_result.status == utf8_parse_result::INVALID) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
}
pos += utf8_result.bytes_consumed;
}
}
// Reached end without finding closing quote
if (!ctx.is_lenient()) {
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
}
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
}
common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) {
auto pos = start_pos;
// Parse string content (without quotes)
while (pos < ctx.input.size()) {
char c = ctx.input[pos];
if (c == '\'') {
// Found closing quote - success (don't consume it)
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
}
if (c == '\\') {
auto result = handle_escape_sequence(ctx, start_pos, pos);
auto result = handle_escape_sequence(ctx, start_pos, pos, p.delimiter);
if (!result.success()) {
return result;
}
@@ -988,8 +937,7 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_ref_parser> ||
std::is_same_v<T, common_peg_until_parser> ||
std::is_same_v<T, common_peg_literal_parser> ||
std::is_same_v<T, common_peg_json_string_parser> ||
std::is_same_v<T, common_peg_python_dict_string_parser> ||
std::is_same_v<T, common_peg_string_parser> ||
std::is_same_v<T, common_peg_chars_parser> ||
std::is_same_v<T, common_peg_any_parser> ||
std::is_same_v<T, common_peg_space_parser>) {
@@ -1065,10 +1013,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", unbounded)";
}
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return "JsonString()";
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return "PythonDictString()";
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
return "String(" + std::string(1, p.delimiter) + ")";
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
return "Until(" + string_join(p.delimiters, " | ") + ")";
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
@@ -1281,47 +1227,25 @@ common_peg_arena common_peg_parser_builder::build() {
// String primitives
common_peg_parser common_peg_parser_builder::json_string_content() {
return wrap(arena_.add_parser(common_peg_json_string_parser{}));
}
common_peg_parser common_peg_parser_builder::single_quoted_string_content() {
return wrap(arena_.add_parser(common_peg_python_dict_string_parser{}));
common_peg_parser common_peg_parser_builder::string_content(char delimiter) {
return wrap(arena_.add_parser(common_peg_string_parser{delimiter}));
}
common_peg_parser common_peg_parser_builder::double_quoted_string() {
return rule("dq-string",
[this]() { return sequence({ literal("\""), json_string_content(), literal("\""), space() }); });
}
common_peg_parser common_peg_parser_builder::single_quoted_string() {
return rule("sq-string",
[this]() { return sequence({ literal("'"), single_quoted_string_content(), literal("'"), space() }); });
}
common_peg_parser common_peg_parser_builder::flexible_string() {
return rule("flexible-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
}
// Generic helpers for object/array structure
common_peg_parser common_peg_parser_builder::generic_object(const std::string & name,
const common_peg_parser & string_parser,
const common_peg_parser & value_parser) {
return rule(name, [this, string_parser, value_parser]() {
auto ws = space();
auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser });
auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) });
return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) });
return rule("double-quoted-string", [this]() {
return sequence({literal("\""), string_content('"'), literal("\""), space()});
});
}
common_peg_parser common_peg_parser_builder::generic_array(const std::string & name,
const common_peg_parser & value_parser) {
return rule(name, [this, value_parser]() {
auto ws = space();
auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) });
return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) });
common_peg_parser common_peg_parser_builder::single_quoted_string() {
return rule("single-quoted-string", [this]() {
return sequence({literal("'"), string_content('\''), literal("'"), space()});
});
}
common_peg_parser common_peg_parser_builder::quoted_string() {
return rule("quoted-string", [this]() {
return choice({double_quoted_string(), single_quoted_string()});
});
}
@@ -1344,7 +1268,7 @@ common_peg_parser common_peg_parser_builder::json_number() {
common_peg_parser common_peg_parser_builder::json_string() {
return rule("json-string", [this]() {
return sequence({literal("\""), json_string_content(), literal("\""), space()});
return sequence({literal("\""), string_content('"'), literal("\""), space()});
});
}
@@ -1361,11 +1285,36 @@ common_peg_parser common_peg_parser_builder::json_null() {
}
common_peg_parser common_peg_parser_builder::json_object() {
return generic_object("json-object", json_string(), json());
return rule("json-object", [this]() {
auto ws = space();
auto member = sequence({json_string(), ws, literal(":"), ws, json()});
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
return sequence({
literal("{"),
ws,
choice({
literal("}"),
sequence({members, ws, literal("}")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::json_array() {
return generic_array("json-array", json());
return rule("json-array", [this]() {
auto ws = space();
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
return sequence({
literal("["),
ws,
choice({
literal("]"),
sequence({elements, ws, literal("]")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::json() {
@@ -1382,7 +1331,9 @@ common_peg_parser common_peg_parser_builder::json() {
}
common_peg_parser common_peg_parser_builder::python_string() {
return rule("python-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
return rule("python-string", [this]() {
return choice({double_quoted_string(), single_quoted_string()});
});
}
common_peg_parser common_peg_parser_builder::python_number() {
@@ -1390,24 +1341,63 @@ common_peg_parser common_peg_parser_builder::python_number() {
}
common_peg_parser common_peg_parser_builder::python_bool() {
return rule("python-bool", [this]() { return sequence({ choice({ literal("True"), literal("False") }), space() }); });
return rule("python-bool", [this]() {
return sequence({
choice({literal("True"), literal("False")}),
space()
});
});
}
common_peg_parser common_peg_parser_builder::python_null() {
return rule("python-none", [this]() { return sequence({ literal("None"), space() }); });
return rule("python-none", [this]() {
return sequence({literal("None"), space()});
});
}
common_peg_parser common_peg_parser_builder::python_dict() {
return generic_object("python-dict", python_string(), python_value());
return rule("python-dict", [this]() {
auto ws = space();
auto member = sequence({python_string(), ws, literal(":"), ws, python_value()});
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
return sequence({
literal("{"),
ws,
choice({
literal("}"),
sequence({members, ws, literal("}")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::python_array() {
return generic_array("python-array", python_value());
return rule("python-array", [this]() {
auto ws = space();
auto elements = sequence({python_value(), zero_or_more(sequence({literal(","), ws, python_value()}))});
return sequence({
literal("["),
ws,
choice({
literal("]"),
sequence({elements, ws, literal("]")})
}),
ws
});
});
}
common_peg_parser common_peg_parser_builder::python_value() {
return rule("python-value", [this]() {
return choice({ python_dict(), python_array(), python_string(), python_number(), python_bool(), python_null() });
return choice({
python_dict(),
python_array(),
python_string(),
python_number(),
python_bool(),
python_null()
});
});
}
@@ -1528,8 +1518,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_chars_parser> ||
std::is_same_v<T, common_peg_space_parser> ||
std::is_same_v<T, common_peg_any_parser> ||
std::is_same_v<T, common_peg_json_string_parser> ||
std::is_same_v<T, common_peg_python_dict_string_parser>) {
std::is_same_v<T, common_peg_string_parser>) {
// These parsers do not have any children
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
for (auto child : p.children) {
@@ -1665,10 +1654,9 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return result + "{" + std::to_string(p.min_count) + "}";
}
return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}";
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
const std::string delim(1, p.delimiter);
return R"(( [^)" + delim + R"(\\] | "\\" ( [)" + delim + R"(\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
if (p.delimiters.empty()) {
return ".*";
@@ -1798,10 +1786,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
{"min_count", p.min_count},
{"max_count", p.max_count}
};
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
return json{{"type", "json_string"}};
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
return json{{ "type", "python_dict_string" }};
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
return json{{"type", "string"}, {"delimiter", std::string(1, p.delimiter)}};
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
return json{{"type", "until"}, {"delimiters", p.delimiters}};
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
@@ -1928,11 +1914,15 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
}
return parser;
}
if (type == "json_string") {
return common_peg_json_string_parser{};
}
if (type == "python_dict_string") {
return common_peg_python_dict_string_parser{};
if (type == "string") {
if (!j.contains("delimiter")) {
throw std::runtime_error("string parser missing delimiter field.");
}
std::string delimiter = j["delimiter"];
if (delimiter.empty()) {
throw std::runtime_error("string parser delimiter is empty.");
}
return common_peg_string_parser{delimiter[0]};
}
if (type == "until") {
if (!j.contains("delimiters") || !j["delimiters"].is_array()) {
+7 -14
View File
@@ -231,8 +231,9 @@ struct common_peg_chars_parser {
int max_count; // -1 for unbounded
};
struct common_peg_json_string_parser {};
struct common_peg_python_dict_string_parser {};
struct common_peg_string_parser {
char delimiter;
};
struct common_peg_until_parser {
std::vector<std::string> delimiters;
@@ -280,8 +281,7 @@ using common_peg_parser_variant = std::variant<
common_peg_any_parser,
common_peg_space_parser,
common_peg_chars_parser,
common_peg_json_string_parser,
common_peg_python_dict_string_parser,
common_peg_string_parser,
common_peg_until_parser,
common_peg_schema_parser,
common_peg_rule_parser,
@@ -340,10 +340,6 @@ class common_peg_parser_builder {
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
// Generic helpers for building object/array structures with configurable string/value parsers.
common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser);
common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser);
public:
common_peg_parser_builder();
@@ -444,13 +440,10 @@ class common_peg_parser_builder {
common_peg_parser single_quoted_string();
// Matches a string that accepts both double-quoted and single-quoted styles.
common_peg_parser flexible_string();
common_peg_parser quoted_string();
// Matches double-quoted string content without the surrounding quotes.
common_peg_parser json_string_content();
// Matches single-quoted string content without the surrounding quotes.
common_peg_parser single_quoted_string_content();
// Matches string content without the surrounding delimiter.
common_peg_parser string_content(char delimiter);
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
// value -> object | array | string | number | true | false | null
+20 -4
View File
@@ -4390,15 +4390,31 @@ class Qwen3Model(Qwen2Model):
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
# a bit hacky, but currently the only way to detect if this is a rerank model
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
if self._is_qwen3_reranker():
self._find_rerank_config()
def _is_qwen3_reranker(self) -> bool:
readme_path = self.dir_model / "README.md"
readme_text = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf-8") as f:
readme_text = f.read()
if "# Qwen3-Reranker" in readme_text:
self._find_rerank_config()
name_hints = [
str(self.dir_model.name),
str(self.hparams.get("_name_or_path", "")),
str(self.hparams.get("model_type", "")),
str(self.origin_hf_arch or ""),
]
name_hints = [hint.lower() for hint in name_hints if hint]
if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower():
return True
if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints):
return True
return "sequenceclassification" in (self.origin_hf_arch or "").lower()
def set_vocab(self):
# deal with intern-s1-mini
+7 -1
View File
@@ -599,7 +599,13 @@ If KleidiAI is enabled, the output will contain a line similar to:
```
load_tensors: CPU_KLEIDIAI model buffer size = 3474.00 MiB
```
KleidiAI's microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm and SME. llama.cpp selects the most efficient kernel based on runtime CPU feature detection. However, on platforms that support SME, you must manually enable SME microkernels by setting the environment variable `GGML_KLEIDIAI_SME=1`.
KleidiAIs microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm, SVE, and SME. Llama.cpp selects the most efficient kernels at runtime based on detected CPU capabilities.
On CPUs that support SME, SME microkernels are enabled automatically using runtime detection.
The environment variable GGML_KLEIDIAI_SME can be used to control SME behavior:
- Not set: enable SME automatically if supported and detected.
- 0: disable SME.
- <n> > 0: enable SME and assume <n> available SME units (override auto detection).
If SME is not supported by the CPU, SME microkernels are always disabled.
Depending on your build target, other higher priority backends may be enabled by default. To ensure the CPU backend is used, you must disable the higher priority backends either at compile time, e.g. -DGGML_METAL=OFF, or during run-time using the command line option `--device none`.
+10 -9
View File
@@ -23,7 +23,7 @@ Legend:
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@@ -31,7 +31,7 @@ Legend:
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -47,6 +47,7 @@ Legend:
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -63,7 +64,7 @@ Legend:
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | ✅ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@@ -75,7 +76,7 @@ Legend:
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -85,24 +86,24 @@ Legend:
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | | ✅ | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
+1689 -6836
View File
File diff suppressed because it is too large Load Diff
+1137 -12992
View File
File diff suppressed because it is too large Load Diff
+17 -4
View File
@@ -1,8 +1,8 @@
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
"Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
@@ -85,8 +85,8 @@
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
"Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
@@ -13591,3 +13591,16 @@
"Vulkan0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","Vulkan"
"Vulkan0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","Vulkan"
Can't render this file because it is too large.
+1 -1
View File
@@ -633,7 +633,7 @@ class SchemaConverter:
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
items = schema.get('items', schema.get('prefixItems'))
if isinstance(items, list):
return self._add_rule(
rule_name,
+6 -1
View File
@@ -8,7 +8,12 @@ extern "C" {
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 6
#define RPC_PROTO_PATCH_VERSION 0
#define RPC_PROTO_PATCH_VERSION 1
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
#endif
#define GGML_RPC_MAX_SERVERS 16
// backend API
+2 -1
View File
@@ -202,8 +202,9 @@
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
File diff suppressed because it is too large Load Diff
+3 -3
View File
@@ -520,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .required_cpu = */ CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
@@ -631,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .required_cpu = */ CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
@@ -801,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .required_cpu = */ CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+56 -5
View File
@@ -28,13 +28,17 @@ template <int K, int N> struct block {
// control size
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding");
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding");
using block_q4_0x4 = block<4, 4>;
using block_q4_0x8 = block<4, 8>;
using block_q4_0x16 = block<4, 16>;
using block_q8_0x4 = block<8, 4>;
using block_q8_0x8 = block<8, 8>;
using block_q8_0x16 = block<8, 16>;
struct block_q4_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
@@ -44,7 +48,14 @@ struct block_q4_Kx8 {
};
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
struct block_q4_Kx16 {
ggml_half d[16]; // super-block scale for quantized scales
ggml_half dmin[16]; // super-block scale for quantized mins
uint8_t scales[192]; // scales and mins, quantized with 6 bits
uint8_t qs[2048]; // 4--bit quants
};
static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding");
struct block_q2_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
ggml_half dmin[8]; // super-block scale for quantized mins
@@ -53,6 +64,13 @@ struct block_q2_Kx8 {
};
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
struct block_q2_Kx16 {
ggml_half d[16]; // Super-block scale for quantized scales
ggml_half dmin[16]; // Super-block scale for quantized mins
uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks)
uint8_t qs[1024]; // Data (16 cols * 64 bytes per block)
};
static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding");
struct block_q5_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
@@ -97,6 +115,12 @@ struct block_iq4_nlx8 {
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
struct block_iq4_nlx16 {
ggml_half d[16]; // deltas for 16 iq4_nl blocks
uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks
};
static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding");
struct block_mxfp4x4 {
uint8_t e[4];
uint8_t qs[QK_MXFP4 * 2];
@@ -109,7 +133,6 @@ struct block_mxfp4x8 {
};
static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
#if defined(__cplusplus)
extern "C" {
#endif
@@ -132,6 +155,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -146,10 +171,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined __riscv_zvfh
void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#endif
// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@@ -170,6 +207,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -184,10 +223,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined __riscv_zvfh
void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#endif
#if defined(__cplusplus)
} // extern "C"
+56 -29
View File
@@ -2,28 +2,29 @@
#include "ggml-cuda/common.cuh"
template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t rq1,
int64_t rq3,
float scale) {
__global__ void __launch_bounds__(S_v, 1)
gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
const int64_t H,
const int64_t n_tokens,
const int64_t n_seqs,
const int64_t sq1,
const int64_t sq2,
const int64_t sq3,
const int64_t sv1,
const int64_t sv2,
const int64_t sv3,
const int64_t sb1,
const int64_t sb2,
const int64_t sb3,
const int64_t rq1,
const int64_t rq3,
const float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
@@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
// Load state column into registers
// GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
// TODO: check optimal path for RDNA1 and RDNA2 devices.
#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
extern __shared__ float s_shared[];
float * s = s_shared + col * S_v;
#else
float s[S_v];
#endif
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
@@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q,
}
}
static size_t calculate_smem(const int sv, int cc)
{
size_t smem = 0;
if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
smem = sv * sv * sizeof(float);
}
return smem;
}
template <bool KDA>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
@@ -129,25 +145,36 @@ static void launch_gated_delta_net(
dim3 grid_dims(H, n_seqs, 1);
dim3 block_dims(S_v, 1, 1);
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
switch (S_v) {
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
case 32: {
constexpr int sv = 32;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
}
case 64: {
constexpr int sv = 64;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
}
case 128: {
constexpr int sv = 128;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
break;
}
default:
GGML_ABORT("fatal error");
break;
+7 -1
View File
@@ -4992,9 +4992,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_GATED_DELTA_NET:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_GATED_DELTA_NET:
//TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
#ifdef GGML_USE_MUSA
return false;
#else
return true;
#endif // GGML_USE_MUSA
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
case GGML_OP_CROSS_ENTROPY_LOSS:
+4 -1
View File
@@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
int row = tid / load_cols;
int col = tid % load_cols;
#pragma unroll
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
for (int idx = 0; idx < total_elems; idx += split_d_inner) {
if (row < (int)split_d_inner) {
smem[row * n_cols + col] = x_block[row * stride_x + col];
}
@@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
col += split_d_inner;
row += col / load_cols;
col = col % load_cols;
if (idx >= total_elems - tid - split_d_inner) {
break;
}
}
__syncthreads();
+22 -2
View File
@@ -75,6 +75,10 @@ struct ggml_metal {
// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
void * abort_callback_data;
// error state - set when a command buffer fails during synchronize
// once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated
bool has_error;
};
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
@@ -158,6 +162,8 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
res->capture_started = false;
res->capture_scope = nil;
res->has_error = false;
res->gf = nil;
res->encode_async = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
@@ -246,7 +252,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
if (status == MTLCommandBufferStatusError) {
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
GGML_ABORT("fatal error");
ctx->has_error = true;
return;
}
}
}
@@ -262,7 +269,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
if (status == MTLCommandBufferStatusError) {
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
GGML_ABORT("fatal error");
// release this and all remaining command buffers before returning
for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {
[ctx->cmd_bufs_ext[j] release];
}
[ctx->cmd_bufs_ext removeAllObjects];
ctx->has_error = true;
return;
}
[cmd_buf release];
@@ -414,6 +429,11 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
}
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
if (ctx->has_error) {
GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__);
return GGML_STATUS_FAILED;
}
// number of nodes encoded by the main thread (empirically determined)
const int n_main = MAX(64, 0.1*gf->n_nodes);
+20 -3
View File
@@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
char base[256];
char name[256];
snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
if (mode == GGML_SCALE_MODE_BILINEAR) {
snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
} else {
snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s_aa=%d", base, antialias);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
return res;
+1 -1
View File
@@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_POOL_1D:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_POOL_2D:
+2
View File
@@ -83,6 +83,7 @@
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -890,6 +891,7 @@ typedef struct {
float sf1;
float sf2;
float sf3;
float poffs;
} ggml_metal_kargs_upscale;
typedef struct {
+38 -24
View File
@@ -1963,6 +1963,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
(
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q4_0 ||
op->src[0]->type == GGML_TYPE_Q4_1 ||
op->src[0]->type == GGML_TYPE_Q5_0 ||
@@ -1977,6 +1978,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q5_K ||
op->src[0]->type == GGML_TYPE_Q6_K ||
op->src[0]->type == GGML_TYPE_Q2_K ||
op->src[0]->type == GGML_TYPE_Q3_K ||
false) && (ne11 >= 4 && ne11 <= 8)
)
)
@@ -3729,32 +3732,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float sf0 = (float)ne0/op->src[0]->ne[0];
const float sf1 = (float)ne1/op->src[0]->ne[1];
const float sf2 = (float)ne2/op->src[0]->ne[2];
const float sf3 = (float)ne3/op->src[0]->ne[3];
float sf0 = (float)ne0/op->src[0]->ne[0];
float sf1 = (float)ne1/op->src[0]->ne[1];
float sf2 = (float)ne2/op->src[0]->ne[2];
float sf3 = (float)ne3/op->src[0]->ne[3];
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
float poffs = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
poffs = 0.0f;
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
}
ggml_metal_kargs_upscale args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.sf0 =*/ sf0,
/*.sf1 =*/ sf1,
/*.sf2 =*/ sf2,
/*.sf3 =*/ sf3
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.sf0 =*/ sf0,
/*.sf1 =*/ sf1,
/*.sf2 =*/ sf2,
/*.sf3 =*/ sf3,
/*.poffs =*/ poffs,
};
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
+170 -1
View File
@@ -3481,6 +3481,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
#endif
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -3531,6 +3538,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
template<typename T0, typename T1, short NR0, typename args_t>
void kernel_mul_mv_t_t_impl(
args_t args,
@@ -4530,7 +4547,9 @@ kernel void kernel_conv_transpose_2d<half>(
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
kernel void kernel_upscale_f32(
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
kernel void kernel_upscale_nearest_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
@@ -4556,6 +4575,156 @@ kernel void kernel_upscale_f32(
}
}
static inline float bilinear_tri(float x) {
return MAX(0.0f, 1.0f - fabs(x));
}
kernel void kernel_upscale_bilinear_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
src0 += i03*args.nb03 + i02*args.nb02;
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (FC_upscale_aa) {
const float support0 = MAX(1.0f, 1.0f / args.sf0);
const float invscale0 = 1.0f / support0;
const float support1 = MAX(1.0f, 1.0f / args.sf1);
const float invscale1 = 1.0f / support1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
float sum = 0.0f;
float wsum = 0.0f;
for (int64_t sy = y_min; sy < y_max; ++sy) {
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
}
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
dst_ptr[i0] = v;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
const float v =
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
(*src10) * fd0 * (1.0f - fd1) +
(*src01) * (1.0f - fd0) * fd1 +
(*src11) * fd0 * fd1;
dst_ptr[i0] = v;
}
}
}
static inline float bicubic_weight1(float x) {
const float a = -0.75f;
return ((a + 2) * x - (a + 3)) * x * x + 1;
}
static inline float bicubic_weight2(float x) {
const float a = -0.75f;
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
}
kernel void kernel_upscale_bicubic_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = (int64_t)floor(f01);
const float fd1 = f01 - (float)i01;
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
const float w_y1 = bicubic_weight1(fd1);
const float w_y2 = bicubic_weight1(1.0f - fd1);
const float w_y3 = bicubic_weight2(2.0f - fd1);
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = (int64_t)floor(f00);
const float fd0 = f00 - (float)i00;
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
const float w_x1 = bicubic_weight1(fd0);
const float w_x2 = bicubic_weight1(1.0f - fd0);
const float w_x3 = bicubic_weight2(2.0f - fd0);
float sum = 0.0f;
for (int dy = -1; dy <= 2; ++dy) {
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
for (int dx = -1; dx <= 2; ++dx) {
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
sum += (*src_ptr) * wx * wy;
}
}
dst_ptr[i0] = sum;
}
}
kernel void kernel_pad_f32(
constant ggml_metal_kargs_pad & args,
device const char * src0,
+91
View File
@@ -874,4 +874,95 @@ static bool fast_fp16_available(const int cc) {
return true; //Intel GPUs always support FP16.
}
enum class block_reduce_method {
MAX,
SUM,
};
template<block_reduce_method method_t, typename T, int warp_size>
struct block_reduce_policy;
template <typename T, typename... Ts>
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
template<typename...>
inline constexpr bool ggml_sycl_dependent_false_v = false;
#define WARP_32_SIZE 32
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::SUM, T, warp_size> {
static T reduce(T val) {
if constexpr (is_any<T, float, sycl::float2, sycl::half2, int>) {
return warp_reduce_sum<warp_size>(val);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
static T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return 0.0f;
} else if constexpr (std::is_same_v<T, sycl::float2>) {
return sycl::float2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, sycl::half2>) {
return sycl::half2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, int>) {
return 0;
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
};
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::MAX, T, warp_size> {
static T reduce(T val) {
if constexpr (is_any<T, float, sycl::half2>) {
return warp_reduce_max<warp_size>(val);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
static T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return -INFINITY;
} else if constexpr (std::is_same_v<T, sycl::half2>) {
return sycl::half2(-INFINITY, -INFINITY);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
};
template <block_reduce_method reduce_method_t, int warp_size, typename T>
static T block_reduce(T val, T * shared_vals, int block_size_template) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
val = block_reduce_policy<reduce_method_t, T,warp_size>::reduce(val);
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
if (block_size > warp_size) {
assert((block_size <= 1024) && (block_size % warp_size) == 0);
const int warp_id = item_ct1.get_local_id(2) / warp_size;
const int lane_id = item_ct1.get_local_id(2) % warp_size;
if (lane_id == 0) {
shared_vals[warp_id] = val;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
size_t nreduce = nwarps / WARP_SIZE;
float tmp = 0.f;
if (lane_id < (static_cast<int>(block_size) / warp_size)) {
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += shared_vals[lane_id + i * WARP_SIZE];
}
}
return block_reduce_policy<reduce_method_t, T, warp_size>::reduce(tmp);
}
return val;
}
#endif // GGML_SYCL_COMMON_HPP
+6
View File
@@ -39,6 +39,11 @@ template<typename dst_t, typename src_t>
return sycl::ext::oneapi::bfloat16(float(x));
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
return static_cast<float>(x);
} else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {
return x.template convert<sycl::half, sycl::rounding_mode::rte>();
} else if constexpr (std::is_same_v<src_t, sycl::float2> &&
std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
return {x.x, x.y};
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
@@ -46,4 +51,5 @@ template<typename dst_t, typename src_t>
}
}
#endif // GGML_SYCL_CONVERT_HPP
+55 -58
View File
@@ -9,23 +9,32 @@
#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
(ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2);
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
if (i >= ne) {
return;
}
int src1_idx = i - offset;
int oz = src1_idx / nb2;
int oy = (src1_idx - (oz * nb2)) / nb1;
int ox = src1_idx % nb1;
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
} else {
dst[i] = x[i];
int64_t src1_idx = i - offset;
int64_t tmp = src1_idx;
const int64_t i13 = tmp / s13;
tmp -= i13 * s13;
const int64_t i12 = tmp / s12;
tmp -= i12 * s12;
const int64_t i11 = tmp / s11;
tmp -= i11 * s11;
const int64_t i10 = tmp;
float val = x[i];
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
}
dst[i] = val;
}
/* Unary OP funcs */
@@ -364,18 +373,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const
namespace ggml_sycl_detail {
static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,
const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) {
int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) *
sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
item_ct1);
});
const int64_t n_elements, const int64_t ne10, const int64_t ne11,
const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3,
const int64_t offset, queue_ptr stream) {
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
});
}
template<typename T>
@@ -402,25 +408,19 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@@ -434,14 +434,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const ggml_tensor * src0 = dst->src[0];
@@ -463,7 +459,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
GGML_ASSERT(src0->type == src1->type);
}
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
sycl::half * src0_p = (sycl::half *) src0_d;
@@ -484,7 +479,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
float * src0_p = (float *) src0_d;
@@ -513,13 +507,9 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
@@ -530,7 +520,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
@@ -539,7 +528,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@@ -868,22 +856,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens
}
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
float * dst_dd = static_cast<float *>(dst->data);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
const int64_t s1 = dst->op_params[0] / sizeof(float);
const int64_t s2 = dst->op_params[1] / sizeof(float);
const int64_t s3 = dst->op_params[2] / sizeof(float);
const int64_t offset = dst->op_params[3] / sizeof(float);
ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst),
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
s1, s2, s3, offset, stream);
}
static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+6 -1
View File
@@ -4145,6 +4145,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ROPE:
ggml_sycl_rope(ctx, dst);
break;
case GGML_OP_ROPE_BACK:
ggml_sycl_rope_back(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_sycl_im2col(ctx, dst);
break;
@@ -4851,6 +4854,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return max_bias == 0.0f;
}
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_IM2COL:
return true;
case GGML_OP_UPSCALE:
@@ -4872,8 +4876,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
k > 0 && k <= 32;
}
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_ACC:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_PAD:
// TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
if (ggml_get_op_params_i32(op, 8) != 0) {
+65 -63
View File
@@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
}
}
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
template<int warp_size>
static void l2_norm_f32(const float * x, float * dst, const int ncols,
const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
const int row = item_ct1.get_group(2);
const int channel = item_ct1.get_group(1);
const int sample = item_ct1.get_group(0);
const int tid = item_ct1.get_local_id(2);
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row * ncols + col];
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp, item_ct1);
if (block_size > WARP_SIZE) {
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
/*
DPCT1118:3: SYCL group functions and algorithms must be encountered in
converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
size_t nreduce = nwarps / WARP_SIZE;
tmp = 0.f;
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += s_sum[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
}
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size);
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[row * ncols + col] = scale * x[row * ncols + col];
dst[col] = scale * x[col];
}
}
@@ -369,42 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
}
}
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
template<int warp_size>
static void l2_norm_f32_sycl(const float * x,
float * dst,
const int ncols,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
queue_ptr stream,
int device) {
const dpct::dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
const dpct::dim3 block_dims(warp_size, 1, 1);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
sycl::nd_range<3>(blocks_num * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
nullptr, warp_size);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
assert(work_group_size % (warp_size * warp_size) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0;
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
sycl::nd_range<3>(blocks_num * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
@@ -634,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
}
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
/*support both WARP_SIZE or WARP_32_SIZE in code
choose by hardware for better performance
*/
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
}
+447 -283
View File
@@ -1,4 +1,5 @@
#include "rope.hpp"
#include "convert.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml.h"
@@ -15,366 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
template <bool forward>
static void rope_yarn(const float theta_extrap, const float freq_scale,
const rope_corr_dims corr_dims, const int64_t i0,
const float ext_factor, float mscale, float &cos_theta,
float &sin_theta) {
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
float ramp_mix =
rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
}
*cos_theta = sycl::cos(theta) * mscale;
*sin_theta = sycl::sin(theta) * mscale;
cos_theta = sycl::cos(theta) * mscale;
sin_theta = sycl::sin(theta) * mscale;
if (!forward) {
sin_theta *= -1.0f;
}
}
template <typename T, bool has_ff>
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
template <bool forward, bool has_ff, typename T, typename D>
static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const int64_t *row_indices, const int set_rows_stride) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row0 = row % ne1;
const int channel0 = row / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0;
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
if (set_rows_stride != 0) {
idst = i1 * s1 + i0;
idst += row_indices[i2] * set_rows_stride;
}
const auto &store_coaelsced = [&](float x0, float x1) {
if constexpr (std::is_same_v<float, D>) {
sycl::float2 v = sycl::float2(x0, x1);
ggml_sycl_memcpy_1<8>(dst + idst, &v);
} else if constexpr (std::is_same_v<sycl::half, D>) {
sycl::half2 v = sycl::half2(x0, x1);
ggml_sycl_memcpy_1<4>(dst + idst, &v);
}
};
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
store_coaelsced(x[ix + 0], x[ix + 1]);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i2 + 0];
const float x1 = x[i2 + 1];
const float x0 = x[ix + 0];
const float x1 = x[ix + 1];
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
store_coaelsced(x0 * cos_theta - x1 * sin_theta,
x0 * sin_theta + x1 * cos_theta);
}
template <typename T, bool has_ff>
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
template <bool forward, bool has_ff, typename T, typename D>
static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const int64_t *row_indices, const int set_rows_stride) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row0 = row % ne1;
const int channel0 = row / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
if (set_rows_stride != 0) {
idst = i1 * s1 + i0 / 2;
idst += row_indices[i2] * set_rows_stride;
}
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i2 + 0];
const float x1 = x[i2 + n_dims / 2];
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims / 2];
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
template <typename T, bool has_ff>
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
template <bool forward, bool has_ff, typename T>
static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const mrope_sections sections, const bool is_imrope) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne00) {
return;
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
return;
}
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sect_dims =
sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];
// store results in dst
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims / 2];
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
}
template <bool forward, bool has_ff, typename T>
static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const mrope_sections sections) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
template <typename T, bool has_ff>
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
const int sect_dims = sections.v[0] + sections.v[1];
const int sector = (i0 / 2) % sect_dims;
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
} else {
theta_base = pos[i2] * dpct::pow(theta_scale, p);
} else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims];
// store results in dst
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
}
template <typename T>
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> block_nums(1, num_blocks_x, nr);
template <bool forward, typename T, typename D>
static void
rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const int64_t *row_indices,
const int set_rows_stride, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) {
/*
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_norm<forward, false>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
} else {
/*
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_norm<forward, true>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
}
}
template <typename T>
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> block_nums(1, num_blocks_x, nr);
template <bool forward, typename T, typename D>
static void
rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const int64_t *row_indices,
const int set_rows_stride, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_neox<forward, false>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
} else {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_neox<forward, true>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
}
}
template <typename T>
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
template <bool forward, typename T>
static void
rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const mrope_sections sections,
const bool is_imrope, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
// Add FP16 capability check if T could be sycl::half
if constexpr (std::is_same_v<T, sycl::half>) {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_multi<forward, false, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections, is_imrope);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_multi<forward, true, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections, is_imrope);
});
}
}
template <bool forward, typename T>
static void
rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const mrope_sections sections,
dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
// rope vision
template <typename T>
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
// Add FP16 capability check if T could be sycl::half
if constexpr (std::is_same_v<T, sycl::half>) {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_vision<forward, false, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_vision<forward, true, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections);
});
}
}
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
template <bool forward>
void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
const ggml_tensor *set_rows = nullptr) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const ggml_tensor *src2 = dst->src[2];
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->src[0]->type == dst->type);
const int64_t ne00 = dst->src[0]->ne[0]; // head dims
const int64_t ne01 = dst->src[0]->ne[1]; // num heads
const int64_t ne02 = dst->src[0]->ne[2]; // num heads
const int64_t nr = ggml_nrows(dst->src[0]);
const float *src0_d = (const float *)src0->data;
const float *src1_d = (const float *)src1->data;
const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
void *dst_d = dst->data;
const int64_t *row_indices = nullptr;
ggml_type dst_type = dst->type;
int set_rows_stride = 0;
if (set_rows != nullptr) {
GGML_ASSERT(forward);
dst_d = set_rows->data;
row_indices = (const int64_t *)set_rows->src[1]->data;
dst_type = set_rows->type;
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
}
dpct::queue_ptr stream = ctx.stream();
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type ||
(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
const int n_dims = ((int32_t *)dst->op_params)[1];
const int mode = ((int32_t *)dst->op_params)[2];
const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
mrope_sections sections;
// RoPE alteration for extended context
float freq_base;
float freq_scale;
float ext_factor;
@@ -382,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
@@ -396,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
sections.v[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
GGML_ASSERT(n_dims == ne00 / 2);
}
const int32_t * pos = (const int32_t *) dst->src[1]->data;
const int32_t *pos = (const int32_t *)src1_d;
const float * freq_factors = nullptr;
if (dst->src[2] != nullptr) {
freq_factors = (const float *) dst->src[2]->data;
const float *freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *)src2->data;
}
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
beta_slow, corr_dims.v);
// compute
if (is_neox) {
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_sycl<forward, float, float>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_sycl<forward, float, sycl::half>(
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_sycl<forward, sycl::half, sycl::half>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else if (is_mrope && !is_vision) {
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, is_imrope, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
is_imrope, main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
ne00, ne01, ne02, s01, s02, s03, s1, s2,
s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims,
freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_sycl<forward>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
sections, is_imrope, stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else if (is_vision) {
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_vision_sycl<forward>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, sections,
stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_sycl<forward>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
sections, stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else {
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_sycl<forward, float, float>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_sycl<forward, float, sycl::half>(
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_sycl<forward, sycl::half, sycl::half>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
}
}
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
ggml_sycl_op_rope(ctx, dst);
ggml_sycl_op_rope_impl<true>(ctx, dst);
}
void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
ggml_sycl_op_rope_impl<false>(ctx, dst);
}
void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
ggml_tensor *set_rows) {
scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
}
+6
View File
@@ -15,6 +15,12 @@
#include "common.hpp"
#define SYCL_ROPE_BLOCK_SIZE 256
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
#endif // GGML_SYCL_ROPE_HPP
+39 -1
View File
@@ -763,6 +763,7 @@ struct vk_device_struct {
vk_pipeline pipeline_ceil[2];
vk_pipeline pipeline_floor[2];
vk_pipeline pipeline_trunc[2];
vk_pipeline pipeline_sgn[2];
vk_pipeline pipeline_add1_f16_f16;
vk_pipeline pipeline_add1_f16_f32;
@@ -4393,6 +4394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(ceil)
CREATE_UNARY(floor)
CREATE_UNARY(trunc)
CREATE_UNARY(sgn)
#undef CREATE_UNARY
#define CREATE_UNARY_RTE(name) \
@@ -9281,6 +9283,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_TRUNC:
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SGN:
return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
default:
break;
}
@@ -12875,6 +12879,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
case GGML_UNARY_OP_SGN:
ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
case GGML_UNARY_OP_XIELU:
@@ -13253,6 +13258,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
vk_buffer buf = buf_ctx->dev_buffer;
if (size == 0) {
return;
}
uint32_t val32 = (uint32_t)value * 0x01010101;
ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
}
@@ -13262,6 +13271,10 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
vk_buffer buf = buf_ctx->dev_buffer;
if (size == 0) {
return;
}
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
@@ -13269,12 +13282,20 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
if (size == 0) {
return;
}
vk_buffer buf = buf_ctx->dev_buffer;
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
if (ggml_nbytes(src) == 0) {
return true;
}
if (ggml_backend_buffer_is_vk(src->buffer)) {
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@@ -13464,6 +13485,10 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
if (size == 0) {
return;
}
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_context cpy_ctx;
@@ -13507,6 +13532,10 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
if (size == 0) {
return;
}
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
@@ -13533,9 +13562,14 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
}
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
// Skip zero-size tensors
if (ggml_nbytes(src) == 0) {
return true;
}
if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
return false;
}
@@ -14975,6 +15009,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
case GGML_UNARY_OP_SGN:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -16141,6 +16176,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_TRUNC:
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_SGN:
tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
break;
default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
@@ -0,0 +1,21 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(sign(float(data_a[i])));
}
@@ -871,6 +871,8 @@ void process_shaders() {
string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+59 -37
View File
@@ -42,11 +42,20 @@
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
// Requires at least two (and multiple of 2) k-quant blocks per tile
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
// default size for legacy matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256
@@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
bool src_overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
src_overlap == other.src_overlap;
}
};
@@ -749,6 +759,36 @@ class ggml_webgpu_shader_lib {
std::vector<std::string> defines;
std::string variant = "mul_mat_vec";
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f16";
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
break;
}
}
// src1 type (vector)
switch (context.src1->type) {
case GGML_TYPE_F32:
@@ -763,39 +803,21 @@ class ggml_webgpu_shader_lib {
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
}
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
break;
}
}
// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
if (key.src0_type >= GGML_TYPE_Q2_K) {
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
@@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.src_overlap = context.src_overlap,
};
+234 -225
View File
@@ -8,7 +8,6 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "pre_wgsl.hpp"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@@ -20,12 +19,18 @@
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <iostream>
#ifdef GGML_WEBGPU_GPU_PROFILE
# include <iomanip>
#endif
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
# include <iostream>
#endif
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
@@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#endif // GGML_WEBGPU_CPU_PROFILE
#ifdef GGML_WEBGPU_GPU_PROFILE
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
#endif
/* Constants */
#define WEBGPU_NUM_PARAM_BUFS 48u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
#define WEBGPU_NUM_PARAM_BUFS 96u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable
// default
@@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::BufferUsage usage,
const char * label);
struct webgpu_pool_bufs {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
};
// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;
std::vector<wgpu::Buffer> free;
// The pool must be synchronized because
// 1. The memset pool is shared globally by every ggml buffer,
@@ -138,7 +137,6 @@ struct webgpu_buf_pool {
size_t cur_pool_size;
size_t max_pool_size;
wgpu::Device device;
wgpu::BufferUsage host_buf_usage;
wgpu::BufferUsage dev_buf_usage;
size_t buf_size;
bool should_grow;
@@ -147,53 +145,47 @@ struct webgpu_buf_pool {
int num_bufs,
size_t buf_size,
wgpu::BufferUsage dev_buf_usage,
wgpu::BufferUsage host_buf_usage,
bool should_grow = false,
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
this->max_pool_size = max_pool_size;
this->cur_pool_size = num_bufs;
this->device = device;
this->host_buf_usage = host_buf_usage;
this->dev_buf_usage = dev_buf_usage;
this->buf_size = buf_size;
this->should_grow = should_grow;
this->max_pool_size = max_pool_size;
this->cur_pool_size = num_bufs;
this->device = device;
this->dev_buf_usage = dev_buf_usage;
this->buf_size = buf_size;
this->should_grow = should_grow;
for (int i = 0; i < num_bufs; i++) {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
free.push_back({ host_buf, dev_buf });
free.push_back(dev_buf);
}
}
webgpu_pool_bufs alloc_bufs() {
wgpu::Buffer alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex);
if (!free.empty()) {
webgpu_pool_bufs bufs = free.back();
wgpu::Buffer buf = free.back();
free.pop_back();
return bufs;
return buf;
}
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
if (!(host_buf && dev_buf)) {
if (!dev_buf) {
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
}
return webgpu_pool_bufs{ host_buf, dev_buf };
return dev_buf;
}
cv.wait(lock, [this] { return !free.empty(); });
webgpu_pool_bufs bufs = free.back();
wgpu::Buffer buf = free.back();
free.pop_back();
return bufs;
return buf;
}
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
void free_bufs(std::vector<wgpu::Buffer> bufs) {
std::lock_guard<std::mutex> lock(mutex);
free.insert(free.end(), bufs.begin(), bufs.end());
cv.notify_all();
@@ -201,12 +193,9 @@ struct webgpu_buf_pool {
void cleanup() {
std::lock_guard<std::mutex> lock(mutex);
for (auto & bufs : free) {
if (bufs.host_buf) {
bufs.host_buf.Destroy();
}
if (bufs.dev_buf) {
bufs.dev_buf.Destroy();
for (auto & buf : free) {
if (buf) {
buf.Destroy();
}
}
free.clear();
@@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool {
#endif
struct webgpu_command {
uint32_t num_kernels;
wgpu::CommandBuffer commands;
std::vector<webgpu_pool_bufs> params_bufs;
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
uint32_t num_kernels;
wgpu::CommandBuffer commands;
std::vector<wgpu::Buffer> params_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_gpu_profile_bufs timestamp_query_bufs;
std::string pipeline_name;
@@ -358,6 +346,13 @@ struct webgpu_global_context_struct {
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
struct webgpu_submission {
wgpu::FutureWaitInfo submit_done;
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<wgpu::FutureWaitInfo> profile_futures;
#endif
};
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
// Points to global instances owned by ggml_backend_webgpu_reg_context
@@ -366,7 +361,8 @@ struct webgpu_context_struct {
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
wgpu::Buffer set_rows_dev_error_buf;
wgpu::Buffer set_rows_host_error_buf;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
@@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
/** End WebGPU object initializations */
/** WebGPU Actions */
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
switch (status) {
case wgpu::WaitStatus::Success:
return true;
case wgpu::WaitStatus::TimedOut:
if (allow_timeout) {
return false;
}
GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
return false;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
return false;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
return false;
}
}
#ifdef GGML_WEBGPU_GPU_PROFILE
static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
futures.erase(std::remove_if(futures.begin(), futures.end(),
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
futures.end());
}
// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block) {
if (futures.empty()) {
return;
}
uint64_t timeout_ms = block ? UINT64_MAX : 0;
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
if (waitStatus == wgpu::WaitStatus::Error) {
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
ggml_backend_webgpu_erase_completed_futures(futures);
}
}
if (futures[0].completed) {
futures.erase(futures.begin());
} else {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
ggml_backend_webgpu_erase_completed_futures(futures);
}
}
}
#endif
// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<webgpu_submission> & subs,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
if (subs.empty()) {
return;
}
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
#endif
subs.erase(subs.begin());
}
}
if (futures.empty()) {
if (subs.empty()) {
return;
}
if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
for (auto & sub : subs) {
while (!sub.submit_done.completed) {
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
ggml_backend_webgpu_handle_wait_status(waitStatus);
}
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
#endif
}
subs.clear();
} else {
// Poll once and return
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::TimedOut:
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
// Poll each submit future once and remove completed submissions.
for (auto sub = subs.begin(); sub != subs.end();) {
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
if (sub->submit_done.completed && sub->profile_futures.empty()) {
#else
if (sub->submit_done.completed) {
#endif
sub = subs.erase(sub);
} else {
++sub;
}
}
}
}
@@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
}
#endif
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
webgpu_global_context ctx,
std::vector<webgpu_command> commands,
webgpu_buf_pool & param_buf_pool,
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
std::vector<webgpu_command> & commands,
webgpu_buf_pool & param_buf_pool) {
std::vector<wgpu::CommandBuffer> command_buffers;
std::vector<webgpu_pool_bufs> params_bufs;
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
std::vector<wgpu::Buffer> params_bufs;
webgpu_submission submission;
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
#endif
@@ -569,14 +601,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
for (const auto & command : commands) {
command_buffers.push_back(command.commands);
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
if (command.set_rows_error_bufs) {
set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
}
}
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
std::vector<wgpu::FutureWaitInfo> futures;
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
@@ -586,27 +613,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
// Free the staged buffers
param_buf_pool.free_bufs(params_bufs);
});
futures.push_back({ p_f });
for (const auto & bufs : set_rows_error_bufs) {
wgpu::Future f = bufs.host_buf.MapAsync(
wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
[set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
} else {
const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
// We can't unmap in here due to WebGPU reentrancy limitations.
if (set_rows_error_buf_pool) {
set_rows_error_buf_pool->free_bufs({ bufs });
}
}
});
futures.push_back({ f });
}
submission.submit_done = { p_f };
#ifdef GGML_WEBGPU_GPU_PROFILE
for (const auto & command : commands) {
@@ -623,14 +630,14 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
// WebGPU timestamps are in ns; convert to ms
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
ctx->shader_gpu_time_ms[label] += elapsed_ms;
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
}
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
});
futures.push_back({ f });
submission.profile_futures.push_back({ f });
}
#endif
return futures;
return submission;
}
static webgpu_command ggml_backend_webgpu_build_multi(
@@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
const std::vector<webgpu_pipeline> & pipelines,
const std::vector<std::vector<uint32_t>> & params_list,
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
GGML_ASSERT(pipelines.size() == params_list.size());
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
GGML_ASSERT(pipelines.size() == workgroups_list.size());
std::vector<webgpu_pool_bufs> params_bufs_list;
std::vector<wgpu::BindGroup> bind_groups;
std::vector<wgpu::Buffer> params_bufs_list;
std::vector<wgpu::BindGroup> bind_groups;
for (size_t i = 0; i < pipelines.size(); i++) {
webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
params_bufs.host_buf.GetSize());
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
for (size_t j = 0; j < params_list[i].size(); j++) {
_params[j] = params_list[i][j];
}
params_bufs.host_buf.Unmap();
wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
uint32_t params_binding_num = entries.size();
entries.push_back({ .binding = params_binding_num,
.buffer = params_bufs.dev_buf,
.offset = 0,
.size = params_bufs.dev_buf.GetSize() });
entries.push_back(
{ .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
@@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
}
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
for (const auto & params_bufs : params_bufs_list) {
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
}
// If there are SET_ROWS operations in this submission, copy their error
// buffers to the host.
if (set_rows_error_bufs) {
encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
set_rows_error_bufs->host_buf.GetSize());
for (size_t i = 0; i < params_bufs_list.size(); i++) {
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
}
#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
webgpu_command result = {};
result.commands = commands;
result.params_bufs = params_bufs_list;
result.set_rows_error_bufs = set_rows_error_bufs;
result.num_kernels = pipelines.size();
#ifdef GGML_WEBGPU_GPU_PROFILE
result.timestamp_query_bufs = ts_bufs;
@@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
uint32_t wg_y = 1,
std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
uint32_t wg_y = 1) {
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
{
pipeline
},
{ params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
{ std::move(params) }, { std::move(bind_group_entries) },
{ { wg_x, wg_y } });
}
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
@@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
webgpu_command command =
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
ggml_backend_webgpu_wait(ctx, futures);
std::vector<webgpu_command> commands = { command };
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
@@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
std::cout << "\nggml_webgpu: gpu breakdown:\n";
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
<< pct << "%)\n";
}
#endif
@@ -978,14 +968,6 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
if (decisions->i64_idx) {
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
error_bufs->host_buf.Unmap();
}
}
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@@ -1018,8 +1000,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
};
if (decisions->i64_idx) {
entries.push_back(
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
entries.push_back({ .binding = 3,
.buffer = ctx->set_rows_dev_error_buf,
.offset = 0,
.size = ctx->set_rows_dev_error_buf.GetSize() });
}
uint32_t threads;
@@ -1029,8 +1013,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
error_bufs);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
}
// Workgroup size is a common constant
@@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
use_fast = (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q6_K:
use_fast = true;
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
use_fast = !is_vec;
break;
default:
break;
}
@@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else if (use_fast) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Fast-path tiled/subgroup calculations
uint32_t wg_m, wg_n;
uint32_t wg_m;
uint32_t wg_n;
if (decisions->use_subgroup_matrix) {
uint32_t wg_m_sg_tile =
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
@@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else { // legacy
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
@@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
}
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
@@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t)src0->ne[dim]
(uint32_t) src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
{
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0)
},
{
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1)
},
{
.binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst)
}
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2172,19 +2169,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_SOFT_MAX:
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
case GGML_OP_UNARY:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_CLAMP:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_FILL:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_LOG:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQR:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQRT:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SIN:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_COS:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_PAD:
@@ -2192,7 +2182,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_ARGMAX:
return ggml_webgpu_argmax(ctx, src0, node);
case GGML_OP_ARGSORT:
return ggml_webgpu_argsort(ctx, src0, node);
case GGML_OP_TOP_K:
// we reuse the same argsort implementation for top_k
return ggml_webgpu_argsort(ctx, src0, node);
@@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
std::vector<webgpu_command> commands;
std::vector<wgpu::FutureWaitInfo> futures;
uint32_t num_batched_kernels = 0;
std::vector<webgpu_command> commands;
std::vector<webgpu_submission> subs;
uint32_t num_batched_kernels = 0;
bool contains_set_rows = false;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
}
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
}
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
num_batched_kernels = 0;
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
num_batched_kernels = 0;
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
// Process events and check for completed submissions
ctx->global_ctx->instance.ProcessEvents();
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
commands.clear();
}
}
if (!commands.empty()) {
auto new_futures =
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
commands.clear();
}
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
// If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
if (contains_set_rows) {
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
ctx->set_rows_host_error_buf.GetSize());
wgpu::CommandBuffer set_rows_commands = encoder.Finish();
ctx->global_ctx->queue.Submit(1, &set_rows_commands);
ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
ctx->set_rows_host_error_buf.GetSize());
const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
ctx->set_rows_host_error_buf.Unmap();
}
ggml_backend_webgpu_wait(ctx->global_ctx, subs);
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
return GGML_STATUS_SUCCESS;
}
@@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
@@ -11,7 +11,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
#endif
#endif // VEC
#ifdef SCALAR
#define VEC_SIZE 1
@@ -23,7 +23,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
#endif
#endif // SCALAR
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
@@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
}
}
#endif
#endif // INIT_SRC0_SHMEM_FLOAT
#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
@@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}
#endif
#endif // INIT_SRC1_SHMEM_FLOAT
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
@@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
}
#endif
#endif // INIT_SRC0_SHMEM_Q4_0
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
shmem[shmem_idx + j * 2 + k] = q_lo;
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_1
#ifdef INIT_SRC0_SHMEM_Q5_0
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_0
#ifdef INIT_SRC0_SHMEM_Q5_1
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_1
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
shmem[shmem_idx + j * 2 + k] = q_val;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_0
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d + m;
shmem[shmem_idx + j * 2 + k] = q_val;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_1
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 42u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx + 40u];
let dmin = src0[scale_idx + 41u];
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
let pos_in_32 = k_in_block % 32u;
let q_b_idx = (block_of_32 / 4u) * 32u;
let shift = (block_of_32 % 4u) * 2u;
let k = (pos_in_32 / 16u) * 16u;
let l = pos_in_32 % 16u;
let is = k_in_block / 16u;
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q2_K
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 55u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx + 54u];
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
let kmask2: u32 = 0x0f0f0f0fu;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
let scale_0 = src0[scale_idx + 48u + (2u*i)];
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
let hmask_0 = src0[scale_idx + (2u*i)];
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
let qs_0 = src0[scale_idx + 16u + (2u*i)];
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
}
let half = k_in_block / 128u; // 0 or 1
let pos_in_half = k_in_block % 128u; // 0-127
let shift_group = pos_in_half / 32u; // 0-3
let pos_in_32 = pos_in_half % 32u; // 0-31
let k_group = pos_in_32 / 16u; // 0 or 1
let l = pos_in_32 % 16u; // 0-15
let q_b_idx = half * 32u; // 0 or 32
let shift = shift_group * 2u; // 0, 2, 4, 6
let k = k_group * 16u; // 0 or 16
let is = k_in_block / 16u; // 0-15
// m increments every 32 elements across entire 256 element block
let m_shift = k_in_block / 32u; // 0-7
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
let sc = get_byte(scale_vals[is / 4u], is % 4u);
let dl = d * (f16(sc) - 32.0);
let q_idx = q_b_idx + k + l;
let hm_idx = k + l;
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3u;
let q_val = (f16(qs_val) - f16(hm)) * dl;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q3_K
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 72u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
// Inner loop over 2 shifts per group
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
var sc: u32;
var mn: u32;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q4_K
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 88u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
// But u increments EVERY 32 elements (after each l loop)
let group_of_64 = k_in_block / 64u; // 0-3
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
let u_shift = k_in_block / 32u; // 0-7
let u: u32 = 1u << u_shift;
var sc: u32;
var mn: u32;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
let qh_byte = get_byte(qh_packed, l % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q5_K
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
let quarter = pos_in_half / 32u;
let l = pos_in_half % 32u;
let ql_b_idx = half * 64u;
let qh_b_idx = half * 32u;
let sc_b_idx = half * 8u;
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13_word = ql13_flat / 4u;
let ql13 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql13_word],
src0[scale_idx + 2u * ql13_word + 1u]
));
let ql13_b = get_byte(ql13, ql13_flat % 4u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24_word = ql24_flat / 4u;
let ql24 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql24_word],
src0[scale_idx + 2u * ql24_word + 1u]
));
let ql24_b = get_byte(ql24, ql24_flat % 4u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh_word = qh_flat / 4u;
let qh = bitcast<u32>(vec2(
src0[scale_idx + 64u + 2u * qh_word],
src0[scale_idx + 64u + 2u * qh_word + 1u]
));
let qh_b = get_byte(qh, qh_flat % 4u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc_word = sc_idx / 4u;
let sc = bitcast<u32>(vec2(
src0[scale_idx + 96u + 2u * sc_word],
src0[scale_idx + 96u + 2u * sc_word + 1u]
));
let sc_val = get_byte_i32(sc, sc_idx % 4u);
let d = src0[scale_idx + 104u];
var q_val: f16;
if (quarter == 0u) {
q_val = q1;
} else if (quarter == 1u) {
q_val = q2;
} else if (quarter == 2u) {
q_val = q3;
} else {
q_val = q4;
}
shmem[elem_idx] = d * f16(sc_val) * q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q6_K
@@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
@@ -1,4 +1,3 @@
enable f16;
#include "common_decls.tmpl"
@@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
}
#endif
#ifdef MUL_ACC_Q4_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 10u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = f32(src0[scale_idx + 1u]);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
let q_lo = f32(q_byte & 0xF) * d + m;
local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q5_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 11u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q5_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 12u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q8_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 17u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q8_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 18u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
let aligned = byte_offset & ~3u;
let idx = bbase + aligned / 2u;
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
}
fn byte_of(v: u32, b: u32) -> u32 {
return (v >> (b * 8u)) & 0xFFu;
}
fn sbyte_of(v: u32, b: u32) -> i32 {
let raw = i32((v >> (b * 8u)) & 0xFFu);
return select(raw, raw - 256, raw >= 128);
}
fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let tid = tig / 2u;
let ix = tig % 2u;
let ip = tid / 8u;
let il = tid % 8u;
let l0 = 4u * il;
let is = 8u * ip + l0 / 16u;
let y_offset = 128u * ip + l0;
let q_offset_l = 64u * ip + l0;
let q_offset_h = 32u * ip + l0;
let nb = tile_size / BLOCK_SIZE;
let k_block_start = k_outer / BLOCK_SIZE;
// Aligned scale byte position (is can be odd)
let sc_base_byte = 192u + (is & ~3u);
let sc_byte_pos = is & 3u;
var local_sum = 0.0;
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
let d_raw = load_u32_at(bbase, 208u);
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
let ql1_u32 = load_u32_at(bbase, q_offset_l);
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
for (var l = 0u; l < 4u; l++) {
let y_base = i * BLOCK_SIZE + y_offset + l;
let yl0 = f32(shared_vector[y_base]);
let yl1 = f32(shared_vector[y_base + 32u]);
let yl2 = f32(shared_vector[y_base + 64u]);
let yl3 = f32(shared_vector[y_base + 96u]);
let q1b = byte_of(ql1_u32, l);
let q2b = byte_of(ql2_u32, l);
let qhb = byte_of(qh_u32, l);
let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32);
let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
sums[0] += yl0 * dq0;
sums[1] += yl1 * dq1;
sums[2] += yl2 * dq2;
sums[3] += yl3 * dq3;
}
local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
sums[2] * f32(sc4) + sums[3] * f32(sc6));
}
return local_sum;
}
#endif
struct MulMatParams {
offset_src0: u32,
offset_src1: u32,
@@ -191,4 +478,3 @@ fn main(
dst[dst_idx / VEC_SIZE] = store_val(group_base);
}
}
+3
View File
@@ -177,6 +177,8 @@ class Keys:
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
KEY_LENGTH_SWA = "{arch}.attention.key_length_swa"
VALUE_LENGTH_SWA = "{arch}.attention.value_length_swa"
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
@@ -188,6 +190,7 @@ class Keys:
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
+9
View File
@@ -773,6 +773,12 @@ class GGUFWriter:
def add_value_length_mla(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
def add_key_length_swa(self, length: int) -> None:
self.add_uint32(Keys.Attention.KEY_LENGTH_SWA.format(arch=self.arch), length)
def add_value_length_swa(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH_SWA.format(arch=self.arch), length)
def add_indexer_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count)
@@ -946,6 +952,9 @@ class GGUFWriter:
def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
def add_rope_dimension_count_swa(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT_SWA.format(arch=self.arch), count)
def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
+2 -2
View File
@@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.35.0"
HTTPLIB_VERSION = "refs/tags/v0.37.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
@@ -15,7 +15,7 @@ vendor = {
# not using latest tag to avoid this issue: https://github.com/ggml-org/llama.cpp/pull/17179#discussion_r2515877926
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.24/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/13d161bc8d856ad61ae46b798bbeffc0f49808e8/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/9634bedb5b5a2ca38c1ee7108a9358a4e233f14d/miniaudio.h": "vendor/miniaudio/miniaudio.h",
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h",
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py",
+4
View File
@@ -230,11 +230,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_KEY_LENGTH_SWA, "%s.attention.key_length_swa" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_SWA, "%s.attention.value_length_swa" },
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
@@ -1084,6 +1087,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
+3
View File
@@ -234,11 +234,14 @@ enum llm_kv {
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ATTENTION_KEY_LENGTH_SWA,
LLM_KV_ATTENTION_VALUE_LENGTH_SWA,
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_FREQ_BASE_SWA,
+12 -8
View File
@@ -2876,19 +2876,23 @@ llama_context * llama_init_from_model(
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
const uint32_t blck_size = ggml_blck_size(params.type_k);
if (model->hparams.n_embd_head_k % blck_size != 0) {
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
return nullptr;
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
return nullptr;
}
}
}
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
const uint32_t blck_size = ggml_blck_size(params.type_v);
if (model->hparams.n_embd_head_v % blck_size != 0) {
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
return nullptr;
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
return nullptr;
}
}
}
+2 -2
View File
@@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence(
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
uint64_t min_times = std::stoull(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
uint64_t max_times = UINT64_MAX; // default: no max limit
@@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence(
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
max_times = std::stoull(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
+5 -5
View File
@@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
const bool last = (
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
);
for (int i = 0; i < n_tokens; ++i) {
@@ -849,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
ubatch (params.ubatch),
n_embd (hparams.n_embd),
n_layer (hparams.n_layer),
n_rot (hparams.n_rot),
n_rot (hparams.n_rot()),
n_ctx (cparams.n_ctx),
n_head (hparams.n_head()),
n_head_kv (hparams.n_head_kv()),
n_embd_head_k (hparams.n_embd_head_k),
n_embd_head_k (hparams.n_embd_head_k()),
n_embd_k_gqa (hparams.n_embd_k_gqa()),
n_embd_head_v (hparams.n_embd_head_v),
n_embd_head_v (hparams.n_embd_head_v()),
n_embd_v_gqa (hparams.n_embd_v_gqa()),
n_expert (hparams.n_expert),
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
@@ -2552,7 +2552,7 @@ void llm_graph_context::build_pooling(
}
// softmax for qwen3 reranker
if (arch == LLM_ARCH_QWEN3) {
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
cur = ggml_soft_max(ctx0, cur);
}
} break;
+28 -4
View File
@@ -62,6 +62,14 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
return n_head/n_head_kv;
}
uint32_t llama_hparams::n_rot(uint32_t il) const {
if (il < n_layer) {
return is_swa(il) ? n_rot_swa : n_rot_full;
}
GGML_ABORT("fatal error");
}
uint32_t llama_hparams::n_embd_inp() const {
uint32_t n_embd_inp = n_embd;
@@ -76,16 +84,32 @@ uint32_t llama_hparams::n_embd_out() const {
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
}
uint32_t llama_hparams::n_embd_head_k(uint32_t il) const {
if (il < n_layer) {
return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full;
}
GGML_ABORT("fatal error");
}
uint32_t llama_hparams::n_embd_head_v(uint32_t il) const {
if (il < n_layer) {
return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full;
}
GGML_ABORT("fatal error");
}
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);
return n_embd_head_k * n_head_kv;
return n_embd_head_k(il) * n_head_kv;
}
uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);
return n_embd_head_v * n_head_kv;
return n_embd_head_v(il) * n_head_kv;
}
bool llama_hparams::is_n_embd_k_gqa_variable() const {
@@ -197,11 +221,11 @@ bool llama_hparams::is_mla() const {
}
uint32_t llama_hparams::n_embd_head_k_mla() const {
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k();
}
uint32_t llama_hparams::n_embd_head_v_mla() const {
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v();
}
bool llama_hparams::has_kv(uint32_t il) const {
+16 -3
View File
@@ -44,13 +44,20 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_layer;
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_rel_attn_bkts = 0;
// different head size for full_attention and SWA layers
uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head
uint32_t n_embd_head_k_swa;
uint32_t n_embd_head_v_swa;
// different RoPE dimensions for full_attention and SWA layers
uint32_t n_rot_full;
uint32_t n_rot_swa;
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla_impl = 0;
uint32_t n_embd_head_v_mla_impl = 0;
@@ -247,12 +254,18 @@ struct llama_hparams {
uint32_t n_gqa(uint32_t il = 0) const;
uint32_t n_rot(uint32_t il = 0) const;
// dimension of main + auxiliary input embeddings
uint32_t n_embd_inp() const;
// dimension of output embeddings
uint32_t n_embd_out() const;
// dimension of key/value embeddings for each head (per layer)
uint32_t n_embd_head_k(uint32_t il = 0) const;
uint32_t n_embd_head_v(uint32_t il = 0) const;
// dimension of key embeddings across all k-v heads
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
+14 -16
View File
@@ -1033,8 +1033,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
return ggml_view_4d(ctx, k,
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k),
hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(k->type, hparams.n_embd_head_k(il)),
ggml_row_size(k->type, n_embd_k_gqa),
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
@@ -1056,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
if (!v_trans) {
// note: v->nb[1] <= v->nb[2]
return ggml_view_4d(ctx, v,
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns,
ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1]
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
@@ -1065,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
// note: v->nb[1] > v->nb[2]
return ggml_view_4d(ctx, v,
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns,
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1]
ggml_row_size(v->type, kv_size), // v->nb[2]
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
@@ -1544,7 +1544,8 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const {
float freq_scale,
uint32_t il) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
@@ -1552,7 +1553,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
const auto & n_rot = hparams.n_rot;
const auto & n_rot = hparams.n_rot(il);
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
// @ngxson : this is a workaround
// for M-RoPE, we want to rotate the whole vector when doing KV shift
@@ -1606,13 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
auto * ctx = res->get_ctx();
auto * gf = res->get_gf();
const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
const auto & n_rot = hparams.n_rot;
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@@ -1626,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const auto n_rot = hparams.n_rot(il);
const auto n_embd_head_k = hparams.n_embd_head_k(il);
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -1638,7 +1636,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
ggml_row_size(layer.k->type, n_embd_k_gqa),
ggml_row_size(layer.k->type, n_embd_nope));
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
ggml_build_forward_expand(gf, cur);
}
+2 -1
View File
@@ -264,7 +264,8 @@ private:
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
float freq_scale,
uint32_t il) const;
ggml_cgraph * build_graph_shift(
llm_graph_result * res,
+1 -1
View File
@@ -918,7 +918,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
} break;
case GGML_OP_ROPE:
{
const int n_embd_head = hparams.n_embd_head_v;
const int n_embd_head = hparams.n_embd_head_v();
const int n_head = hparams.n_head();
ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
+6 -3
View File
@@ -186,8 +186,10 @@ void llama_model_saver::add_kv_from_model() {
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
@@ -199,7 +201,8 @@ void llama_model_saver::add_kv_from_model() {
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full);
add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa);
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
+51 -39
View File
@@ -459,26 +459,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
// gpt-j n_rot = rotary_dim
hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
hparams.n_embd_head_k_full = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false);
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
hparams.n_embd_head_v_full = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);
// sanity check for n_rot (optional)
hparams.n_rot = hparams.n_embd_head_k;
hparams.n_rot_full = hparams.n_embd_head_k_full;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full, false);
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
if (hparams.n_rot_full != hparams.n_embd_head_k_full) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot_full, hparams.n_embd_head_k_full));
}
}
} else {
hparams.n_rot = 0;
hparams.n_embd_head_k = 0;
hparams.n_embd_head_v = 0;
hparams.n_rot_full = 0;
hparams.n_embd_head_k_full = 0;
hparams.n_embd_head_v_full = 0;
}
// head size and n_rot for SWA layers
{
hparams.n_embd_head_k_swa = hparams.n_embd_head_k_full;
hparams.n_embd_head_v_swa = hparams.n_embd_head_v_full;
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa, false);
hparams.n_rot_swa = hparams.n_rot_full;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false);
}
// for differentiating model types
@@ -1114,10 +1125,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
break;
default: type = LLM_TYPE_UNKNOWN;
}
// Load attention parameters
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
} break;
case LLM_ARCH_PLAMO3:
{
@@ -1212,7 +1219,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
hparams.f_attention_scale = type == LLM_TYPE_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
} break;
case LLM_ARCH_GEMMA3:
{
@@ -1245,7 +1252,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
hparams.f_attention_scale = type == LLM_TYPE_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
} break;
case LLM_ARCH_GEMMA3N:
{
@@ -1294,7 +1301,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case 24: type = LLM_TYPE_0_3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
} break;
case LLM_ARCH_STARCODER2:
@@ -2487,7 +2494,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl);
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda);
@@ -2518,6 +2524,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
// full_attention layer only use half of the RoPE dimensions
hparams.n_rot_full = hparams.n_rot_full / 2;
// MoE + SWA parameters
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
@@ -2661,13 +2670,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_head_v = hparams.n_embd_head_v;
const int64_t n_embd_head_k = hparams.n_embd_head_k();
const int64_t n_embd_head_v = hparams.n_embd_head_v();
const int64_t n_ff = hparams.n_ff();
const int64_t n_embd_gqa = n_embd_v_gqa;
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_token_types = vocab.n_token_types();
const int64_t n_rot = hparams.n_rot;
const int64_t n_rot = hparams.n_rot();
const int64_t n_expert = hparams.n_expert;
const int64_t n_expert_used = hparams.n_expert_used;
const int64_t n_ctx_train = hparams.n_ctx_train;
@@ -2967,8 +2976,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_MINICPM3:
{
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
const int64_t q_lora_rank = hparams.n_lora_q;
const int64_t kv_lora_rank = hparams.n_lora_kv;
@@ -3840,8 +3849,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
// attention parameters
const uint32_t qk_dim = hparams.n_embd_head_k;
const uint32_t v_dim = hparams.n_embd_head_v;
const uint32_t qk_dim = hparams.n_embd_head_k();
const uint32_t v_dim = hparams.n_embd_head_v();
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -3901,8 +3910,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_PLAMO3:
{
const int64_t head_dim_q = hparams.n_embd_head_k;
const int64_t head_dim_v = hparams.n_embd_head_v;
const int64_t head_dim_q = hparams.n_embd_head_k();
const int64_t head_dim_v = hparams.n_embd_head_v();
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4649,7 +4658,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_SEED_OSS:
{
const uint32_t head_dim = hparams.n_embd_head_k;
const uint32_t head_dim = hparams.n_embd_head_k();
const int64_t n_qo_dim = n_head * head_dim;
const int64_t n_kv_dim = n_head_kv * head_dim;
@@ -4878,7 +4887,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
GGML_ASSERT(n_embd_head_qk_nope >= 1);
@@ -4957,8 +4966,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_PLM:
{
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
const int64_t kv_lora_rank = hparams.n_lora_kv;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -5396,7 +5405,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
const int64_t q_lora_rank = hparams.n_lora_q;
@@ -5680,7 +5689,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_expert = hparams.n_expert;
const int64_t n_expert_used = hparams.n_expert_used;
const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp;
const int64_t head_dim = hparams.n_embd_head_k;
const int64_t head_dim = hparams.n_embd_head_k();
const int64_t n_qo_dim = n_head * head_dim;
const int64_t n_kv_dim = n_head_kv * head_dim;
@@ -6968,7 +6977,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA)
// Note: hparams.n_rot may be 72 (from conversion) but actual is 64
const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim
const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0);
// Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled)
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i),
@@ -7339,7 +7348,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer.
uint32_t n_rot_max = 0;
for (int i = 0; i < n_layer; ++i) {
n_rot_max = std::max(n_rot_max, hparams.n_rot);
n_rot_max = std::max(n_rot_max, hparams.n_rot(i));
}
if (n_rot_max == 0) {
n_rot_max = n_rot;
@@ -7674,11 +7683,11 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
@@ -7702,6 +7711,9 @@ void llama_model::print_info() const {
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa);
LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa);
LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa);
LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa);
LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa);
}
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
+470 -286
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -1,8 +1,8 @@
#include "models.h"
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -3,10 +3,10 @@
llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,10 +2,10 @@
llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -1,10 +1,10 @@
#include "models.h"
llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,10 +2,10 @@
llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+2 -2
View File
@@ -2,10 +2,10 @@
llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;
+2 -2
View File
@@ -1,10 +1,10 @@
#include "models.h"
llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;
+2 -2
View File
@@ -2,9 +2,9 @@
llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;
+2 -2
View File
@@ -1,10 +1,10 @@
#include "models.h"
llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -3,10 +3,10 @@
#include <float.h>
llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+2 -2
View File
@@ -2,10 +2,10 @@
llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -1,11 +1,11 @@
#include "models.h"
llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,11 +2,11 @@
llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * inpL;
ggml_tensor * cur;
+2 -2
View File
@@ -1,9 +1,9 @@
#include "models.h"
llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
const float f_logit_scale = hparams.f_logit_scale;
+2 -2
View File
@@ -4,9 +4,9 @@
llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
const float f_logit_scale = hparams.f_logit_scale;
+3 -3
View File
@@ -1,11 +1,11 @@
#include "models.h"
llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -3,10 +3,10 @@
llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,10 +2,10 @@
llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+1 -1
View File
@@ -8,7 +8,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
+3 -3
View File
@@ -2,10 +2,10 @@
llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -5,10 +5,10 @@
llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
//copied from qwen2
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,10 +2,10 @@
llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,10 +2,10 @@
llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+2 -2
View File
@@ -1,9 +1,9 @@
#include "models.h"
llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,10 +2,10 @@
llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -4,10 +4,10 @@
llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -4,10 +4,10 @@
template <bool iswa>
llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+1 -1
View File
@@ -2,7 +2,7 @@
llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
llm_build_mamba_base(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
ggml_tensor * cur;
ggml_tensor * inpL;
+3 -3
View File
@@ -2,11 +2,11 @@
llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
GGML_ASSERT(n_embd_head == n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
+1 -1
View File
@@ -2,7 +2,7 @@
llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
ggml_tensor * cur;
ggml_tensor * inpL;
+1 -1
View File
@@ -2,7 +2,7 @@
llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
ggml_tensor * cur;
ggml_tensor * inpL;
+1 -1
View File
@@ -1,7 +1,7 @@
#include "models.h"
llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
ggml_tensor * cur;
ggml_tensor * inpL;
+1 -1
View File
@@ -2,7 +2,7 @@
template <bool iswa>
llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_k;
const int64_t n_embd_head = hparams.n_embd_head_k();
ggml_tensor * cur;
ggml_tensor * inpL;
+1 -1
View File
@@ -3,7 +3,7 @@
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params),
model(model),
n_embd_head(model.hparams.n_embd_head_k),
n_embd_head(model.hparams.n_embd_head_k()),
n_embd_altup(model.hparams.n_embd_altup),
n_altup(model.hparams.n_altup),
i_altup_act(model.hparams.i_altup_act) {
+2 -2
View File
@@ -1,9 +1,9 @@
#include "models.h"
llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+2 -2
View File
@@ -3,10 +3,10 @@
llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+2 -2
View File
@@ -1,10 +1,10 @@
#include "models.h"
llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_head = hparams.n_embd_head_v();
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * cur;
ggml_tensor * pos;

Some files were not shown because too many files have changed in this diff Show More