Compare commits

...

22 Commits

Author SHA1 Message Date
Max Krasnyansky 13e673863b hexagon: flash attention rework (optimizations, accuracy improvements, etc) (#25085)
* hex-mm: fold mm quant tasks into the main matmul threads

* hex-mm: minor formatting fixes

* hex-mm: cleanup is_quant checks in dma dispatch

* hex-mm: fix dst-spad alignment

* hex-mm: move fp kernels in the hvx-mm-kernels header

* hex-mm: fuse with ADD

* hex-fa: factor out ukernels into separate headers and unify the rest

* hex-fa: move kernel-params compute into the host

* hex-fa: refactor vtcm alloc for consistency

* hex-fa: add support for FA_SELECT

* hex-fa: update tracing insrumentation to cover all functions

* hex-fa: update hvx fallback thresholds to recover t/g regressions

* hex-fa: update tracing instrumentation

* hex-fa: improved tracing with additional events

* hex-fa: optimize mask processing (fastdiv, etc)

* hex-fa: improve mask dma caching

* hmx-fa: change loop order to maximize mask cache hits

* hex-fa: remove over instrumentation

* hex-fa: breakdown QKV prep trace events

* hmx-fa: further mask proc optimizations

* hex-fa: mask broadcast is the common case, optimize for that

* hex-fa: use aligned loads where possible

* hex-fa: update loops to use uint32_t indices

* hmx-fa: fold vtcm init into q prep task

* hex-fa: update rest of the hmx funcs to use uint32_t

* hmx-fa: fold build_d into the main softmax loop

* hmx-fa: start kv dmas earlier

* hmx-fa: start mask dma a bit earlier

* hex-fa: precompute rows per task to avoid divs

* hmx-fa: specialize fa_o_store for f16 and f32

* hmx-fa: prelim support for Sinks

* hmx-fa: keep softmax accumulators in fp32

* hex-fa: add tanh_f16 and exp2_f16 and use that in FA

* hex-fa: use fp16 math in the hvx kernel

* hex-fa: avoid expensive float -> __fp16 cast for slopes and softcap

* hex-fa: replace most vec_exp_f32 with vec_exp2_f16

* hmx-fa: vectorize sinks update

* hex-fa: minor formatting

* hmx-fa: fold softcap loop into the tile load

* hmx-fa: use vectoralias to populate sinks

* hex-fa: remove redudant check

* hex-fa: fix vtcm size compute to use fp32 for accumulators

* hex-mm: fix trailing spaces

* hmx-fa: dont use -inf to init mask to avoid conversion overflows

* hex-fa: no need to explicitly guard -inf in the f16->f32 converter now

* hmx-fa: cleanup fa sinks handling

* hex-mm: fixed src2 stride handling when mm is fused with add

* hex-fa: make lto happy
2026-07-01 06:59:19 -07:00
Johannes Gäßler b820cc8e6f CUDA: consistent use of __restrict__ + PDL for FA (#25185) 2026-07-01 10:55:14 +02:00
ragz4125 6dbc1174b8 ggml-cpu: add AVX2 optimization for nvfp4 dot product and use UE4M3 LUT (#23961) 2026-07-01 15:31:20 +08:00
Aleksander Grygier 9d88e7cedd ui Prevent tool messages from incorrectly appending to other conversations (#25177)
* fix: Prevent tool messages from incorrectly appending to other conversations

* ui: prevent agentic loop from poisoning another conv's currNode

* ui: make editedContent a  so background recompute does not wipe in-progress edits

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-07-01 09:25:18 +02:00
Aleksander Grygier 7af4279f45 ui: Remove PWA navigate fallback to prevent caching API endpoint requests (#25174) 2026-07-01 07:32:55 +02:00
lhez fd1a05791d opencl: initial q1_0 support (#25160)
* opencl: general q1_0 support

* opencl: add Adreno GEMM/GEMV for q1_0
2026-06-30 21:43:20 -07:00
fairydreaming 0eca4d490e cuda : prevent integer truncation and overflow errors when using KQ mask strides in flash_attn_mask_to_KV_max kernel (#24945)
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2026-06-30 20:47:05 +02:00
Jürgen Schmied 4f31eedb0c model : register t_layer_inp for qwen3next (#25141)
* Fix input assignment in layer processing loop

Fix DFLASH for qwen-coder-next

* add line break

Added tensor for attention normalization in Qwen3 model.
2026-06-30 17:57:14 +02:00
Pascal 799fcc04a5 common,server: handle bracketed IPv6 literals in URL authority (#25140)
* common,server: handle bracketed IPv6 literals in URL authority

Parse the [host]:port form (RFC 3986) and bracket IPv6 hosts when
formatting a URL authority: listening log, proxy Host header, proxy
log, client rebuild. The per-request remote_addr stays bare.

* common: restore unsupported scheme throw in url parser

Address @ngxson review: keep the explicit reject in port resolution so
the block stays self-contained. Non-http(s) schemes still throw (also
gated at the top of common_http_parse_url).
2026-06-30 16:16:44 +02:00
Matt Jallo 931eb37f8c CUDA: fix get_rows_back for tables with more than 65535 rows (grid-y clamp + stride) (#25103) 2026-06-30 14:16:24 +02:00
Johannes Gäßler e495d1e748 CUDA: fix Gemma E4B MTP FlashAttention (#25148)
* CUDA: fix Gemma E4B MTP FlashAttention

* remove unused template declaration
2026-06-30 14:06:54 +02:00
Kevin Liu f708a5b2ca vulkan: roll bk loop in matmul for asahi linux (#24663)
* vulkan: roll bk loop in matmul for asahi linux

* vulkan: fix inline comment

* vulkan: revert BK-loop unroll change

* vulkan: edit spirv directly for asahi roll bk loop

* vulkan: remove trailing whitespace at the end of comments
2026-06-30 12:27:38 +02:00
zduford d9df11006f HIP: use hipBLAS for dense prefill on gfx900, keep MMQ for MoE (#24588)
* HIP: keep MMQ for gfx900 MoE and Q8_0, use hipBLAS for dense K-quants

Assisted-by: GitHub Copilot CLI

* HIP: tighten conditional block to be explicitly for gfx900

* HIP: Further simplified gfx900 conditional block

* removed unnecessary comment
2026-06-30 11:51:38 +02:00
Masashi Yoshimura 6c5de1cc83 ggml-webgpu: add support for NVFP4 (#25143) 2026-06-30 17:20:04 +09:00
Oliver Simons 86b94708f2 Revert "sched : reintroduce less synchronizations during split compute (#20793)" (#25138) 2026-06-30 08:41:45 +08:00
Adrien Gallouët 6f4f53f2b7 common : dedup preset and cached model entries in /v1/models (#25131)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-29 17:37:23 +02:00
Ruben Ortlam 25a1d63f43 vulkan: use flops instead of weight tensor size for submission heuristic (#25005)
* vulkan: extract flops calculation into function

* use flops instead of matmul src0 tensor size for submission threshold

* use unsigned ints
2026-06-29 15:24:44 +02:00
Aman Gupta 8c146a8366 DeepSeek V4 (#24162)
* convert: add dsv4 conversion

* add basic setup

* add llm_graph_input_dsv4

* add save-load state

* add sinkhorn eps - correction by @fairydreaming

* add rope fix

* cleanup dead code

* fix bugs

* support pro model: added by @fairydreaming

* remove redundant V cache

* Chat template

* remove debugging leftovers

* Add mechanism for inlining templates based on architecture

* s/deepseek-v4-flash/deepseek4/g

* s/deepseek-v4-flash/deepseek4/g continued

* enable graph reuse

* enable FA

* fix test llama archs

* rename

* compatibility with antirez ds4 GGUFs

* simplified set_gguf_parameters() by calling super class method, replaced moe.score_func with expert_gating_func.

* reserve worst-case kv-cache

* revert max split inputs

* address review comments

* add padding to enable FA

* pad only the final value of plan.n_kv to 256

* remove built-in cpp chat template

* cont: remove cpp built-in template

* rm outdated test

* replace ggml_view_3d() with ggml_reshape_3d()

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* only support n_seq=1 for now

* remove unused var

* cont: remove unused var

* use scale bias

* use correct ptr for can_reuse

* remove gen-chat-inline-templates.py

* simplify graph reuse

* cont: cleanup

* remove unused inputs

* enable partial checkpointing

* add correct shape for kq_mask + set llama_model_n_swa to 0 for dsv4

* precompute source_idx + add comment about dummy write

* support multi-seq

* remove restored_trim_pos

* use split_equal when possible

* fix indent

* address review comments

* use LLM_KV

* fix ci

---------

Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: fairydreaming <166155368+fairydreaming@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-29 16:58:51 +08:00
seryogakovalyov 6cb18b2f2e tools/ui: restore Tailwind scanning in ignored worktrees (#24879) 2026-06-29 10:55:52 +02:00
o7si 277a105dc8 common : remove unused regex-partial (#25118) 2026-06-29 08:48:39 +02:00
Xuan-Son Nguyen b3fed31b99 jinja, chat: add --reasoning-preserve flag (#25105)
* jinja, chat: add --reasoning-preserve flag

* correct help message
2026-06-28 23:33:51 +02:00
Aleksander Grygier dbdaece23d Revert "ui: fix accessibility for hover-gated interactive elements assisted by claude(in debugging and tests) (#24727)" (#25098) 2026-06-28 21:30:03 +02:00
109 changed files with 10637 additions and 3834 deletions
-2
View File
@@ -94,10 +94,8 @@ add_library(${TARGET}
peg-parser.h
preset.cpp
preset.h
regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h
sampling.cpp
sampling.h
speculative.cpp
+14
View File
@@ -3296,6 +3296,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--reasoning-preserve"},
{"--no-reasoning-preserve"},
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
[](common_params & params, bool value) {
if (value) {
params.default_template_kwargs["preserve_reasoning"] = "true";
} else {
params.default_template_kwargs["preserve_reasoning"] = "false";
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
+4
View File
@@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
bool enabled = inp["preserve_reasoning"].get<bool>();
jinja::caps_apply_preserve_reasoning(ctx, enabled);
}
jinja::global_from_json(ctx, inp, inputs.mark_input);
+28 -6
View File
@@ -11,6 +11,11 @@ struct common_http_url {
std::string path;
};
// bracket an IPv6 literal host for a URL authority (RFC 3986)
static std::string common_http_format_host(const std::string & host) {
return host.find(':') != std::string::npos ? "[" + host + "]" : host;
}
static common_http_url common_http_parse_url(const std::string & url) {
common_http_url parts;
auto scheme_end = url.find("://");
@@ -49,11 +54,28 @@ static common_http_url common_http_parse_url(const std::string & url) {
parts.path = "/";
}
auto colon_pos = parts.host.find(':');
// split the authority into host and optional port, a bracketed IPv6 literal keeps its inner colons (RFC 3986)
std::string port_str;
if (!parts.host.empty() && parts.host.front() == '[') {
auto close = parts.host.find(']');
if (close == std::string::npos) {
throw std::runtime_error("invalid IPv6 URL authority: " + parts.host);
}
auto after = parts.host.substr(close + 1);
if (!after.empty() && after.front() == ':') {
port_str = after.substr(1);
}
parts.host = parts.host.substr(1, close - 1);
} else {
auto colon_pos = parts.host.find(':');
if (colon_pos != std::string::npos) {
port_str = parts.host.substr(colon_pos + 1);
parts.host = parts.host.substr(0, colon_pos);
}
}
if (colon_pos != std::string::npos) {
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
parts.host = parts.host.substr(0, colon_pos);
if (!port_str.empty()) {
parts.port = std::stoi(port_str);
} else if (parts.scheme == "http") {
parts.port = 80;
} else if (parts.scheme == "https") {
@@ -83,7 +105,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
#endif
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
httplib::Client cli(parts.scheme + "://" + common_http_format_host(parts.host) + ":" + std::to_string(parts.port));
if (!parts.user.empty()) {
cli.set_basic_auth(parts.user, parts.password);
@@ -95,5 +117,5 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
}
static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + common_http_format_host(parts.host) + parts.path;
}
+44 -23
View File
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
namespace jinja {
using caps_json_fn = std::function<json()>;
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
using caps_ctx_fn = std::function<void(context &)>;
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
ctx.set_val("preserve_thinking", mk_val<value_bool>(enabled));
ctx.set_val("clear_thinking", mk_val<value_bool>(!enabled));
ctx.set_val("truncate_history_thinking", mk_val<value_bool>(!enabled));
}
static void caps_try_execute(jinja::program & prog,
const caps_json_fn & messages_fn,
const caps_ctx_fn & ctx_fn,
const caps_json_fn & tools_fn,
const caps_analyze_fn & analyze_fn) {
context ctx;
ctx.is_get_stats = true;
jinja::global_from_json(ctx, json{
{"messages", messages_fn()},
{"tools", tools_fn()},
{"tools", tools_fn ? tools_fn() : json::array()},
{"bos_token", ""},
{"eos_token", ""},
{"add_generation_prompt", true}
}, true);
if (ctx_fn) {
ctx_fn(ctx);
}
auto messages = ctx.get_val("messages");
auto tools = ctx.get_val("tools");
@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
// ignore exceptions during capability analysis
}
analyze_fn(success, messages, tools);
analyze_fn(success, messages, tools, result);
}
// for debugging only
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
}
});
},
[&]() {
// tools
return json{nullptr};
},
[&](bool success, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool success, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
},
[&](bool, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (!content->stats.used) {
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
return; // Nothing can be inferred
}
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
result.supports_tool_calls = false;
result.supports_tools = false;
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & /*tools*/) {
[&](bool success, value & messages, value &, const std::string &) {
if (!success) {
result.supports_parallel_tool_calls = false;
return;
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
// case: preserve reasoning content in chat history
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "user"},
{"content", "User message"}
},
{
{"role", "assistant"},
{"content", "Assistant message"},
// check of reasoning_content deeper in the history, not just the last assistant message
{"reasoning_content", reasoning_placeholder}
},
{
{"role", "user"},
{"content", "User message"}
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
[&](context & ctx) {
caps_apply_preserve_reasoning(ctx, true);
},
[&](bool, value & messages, value &) {
auto & content = messages->at(1)->at("reasoning_content");
caps_print_stats(content, "messages[1].reasoning_content");
if (content->stats.used) {
nullptr, // tools_fn
[&](bool, value &, value &, const std::string & output) {
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
if (output.find(reasoning_placeholder) != std::string::npos) {
result.supports_preserve_reasoning = true;
}
}
+5 -1
View File
@@ -12,7 +12,9 @@ struct caps {
bool supports_tool_calls = true;
bool supports_system_role = true;
bool supports_parallel_tool_calls = true;
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
// supports preserve reasoning trace in the full history, not just the last assistant message
bool supports_preserve_reasoning = false;
// one of the 2 content capabilities must be true
bool supports_string_content = true;
@@ -29,4 +31,6 @@ struct caps {
caps caps_get(jinja::program & prog);
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
} // namespace jinja
+29 -4
View File
@@ -7,6 +7,7 @@
#include <fstream>
#include <sstream>
#include <filesystem>
#include <regex>
static std::string rm_leading_dashes(const std::string & str) {
size_t pos = 0;
@@ -16,6 +17,23 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
static std::string canonical_tag(const std::string & tag) {
static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
std::smatch m;
if (std::regex_search(tag, m, re_tag)) {
std::string canon = m[1].str();
for (char & c : canon) {
c = (char) std::toupper((unsigned char) c);
}
return canon;
}
std::string upper = tag;
for (char & c : upper) {
c = (char) std::toupper((unsigned char) c);
}
return upper;
}
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
@@ -270,11 +288,18 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
for (auto section : ini_data) {
common_preset preset;
if (section.first.empty()) {
preset.name = COMMON_PRESET_DEFAULT_NAME;
} else {
preset.name = section.first;
std::string section_name = section.first.empty() ? std::string(COMMON_PRESET_DEFAULT_NAME) : section.first;
if (section_name != "*" && section_name != COMMON_PRESET_DEFAULT_NAME) {
auto colon_idx = section_name.rfind(':');
if (colon_idx != std::string::npos) {
std::string tag = section_name.substr(colon_idx + 1);
std::string canon_tag = canonical_tag(tag);
if (canon_tag != tag) {
section_name = section_name.substr(0, colon_idx + 1) + canon_tag;
}
}
}
preset.name = section_name;
LOG_DBG("loading preset: %s\n", preset.name.c_str());
for (const auto & [key, value] : section.second) {
if (key == "version") {
-204
View File
@@ -1,204 +0,0 @@
#include "regex-partial.h"
#include "common.h"
#include <functional>
#include <optional>
common_regex::common_regex(const std::string & pattern) :
pattern(pattern),
rx(pattern),
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
std::smatch match;
if (pos > input.size()) {
throw std::runtime_error("Position out of bounds");
}
auto start = input.begin() + pos;
auto found = as_match
? std::regex_match(start, input.end(), match, rx)
: std::regex_search(start, input.end(), match, rx);
if (found) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
for (size_t i = 0; i < match.size(); ++i) {
auto begin = pos + match.position(i);
res.groups.emplace_back(begin, begin + match.length(i));
}
return res;
}
std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
if ((!as_match) || it == input.begin()) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
const size_t begin = std::distance(input.begin(), it);
const size_t end = input.size();
if (begin == std::string::npos || end == std::string::npos || begin > end) {
throw std::runtime_error("Invalid range");
}
res.groups.push_back({begin, end});
return res;
}
}
}
return {};
}
/*
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
- /a|b/ -> ^(a|b)
- /a*?/ -> error, could match ""
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
- /.*?ab/ -> ^((?:b)?a) (omit .*)
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
const auto end = pattern.end();
std::function<std::string()> process = [&]() {
std::vector<std::vector<std::string>> alternatives(1);
std::vector<std::string> * sequence = &alternatives.back();
while (it != end) {
if (*it == '[') {
auto start = it;
++it;
while (it != end) {
if ((*it == '\\') && (++it != end)) {
++it;
} else if ((it != end) && (*it == ']')) {
break;
} else {
++it;
}
}
if (it == end) {
throw std::runtime_error("Unmatched '[' in pattern");
}
++it;
sequence->push_back(std::string(start, it));
} else if (*it == '*' || *it == '?' || *it == '+') {
if (sequence->empty()) {
throw std::runtime_error("Quantifier without preceding element");
}
sequence->back() += *it;
auto is_star = *it == '*';
++it;
if (is_star) {
if (it != end && *it == '?') {
++it;
}
}
} else if (*it == '{') {
if (sequence->empty()) {
throw std::runtime_error("Repetition without preceding element");
}
++it;
auto start = it;
while (it != end && *it != '}') {
++it;
}
if (it == end) {
throw std::runtime_error("Unmatched '{' in pattern");
}
auto parts = string_split(std::string(start, it), ",");
++it;
if (parts.size() > 2) {
throw std::runtime_error("Invalid repetition range in pattern");
}
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
if (s.empty()) {
return def;
}
return std::stoi(s);
};
auto min = parseOptInt(parts[0], 0);
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
if (min && max && *max < *min) {
throw std::runtime_error("Invalid repetition range in pattern");
}
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
auto part = sequence->back();
sequence->pop_back();
for (int i = 0; i < *min; i++) {
sequence->push_back(part);
}
if (max) {
for (int i = *min; i < *max; i++) {
sequence->push_back(part + "?");
}
} else {
sequence->push_back(part + "*");
}
} else if (*it == '(') {
++it;
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
it += 2;
}
auto sub = process();
if (*it != ')') {
throw std::runtime_error("Unmatched '(' in pattern");
}
++it;
auto & part = sequence->emplace_back("(?:");
part += sub;
part += ")";
} else if (*it == ')') {
break;
} else if (*it == '|') {
++it;
alternatives.emplace_back();
sequence = &alternatives.back();
} else if (*it == '\\' && (++it != end)) {
auto str = std::string("\\") + *it;
sequence->push_back(str);
++it;
} else if (it != end) {
sequence->push_back(std::string(1, *it));
++it;
}
}
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function.
std::vector<std::string> res_alts;
for (const auto & parts : alternatives) {
auto & res = res_alts.emplace_back();
for (size_t i = 0; i < parts.size() - 1; i++) {
res += "(?:";
}
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
res += *it;
if (it != parts.rend() - 1) {
res += ")?";
}
}
}
return string_join(res_alts, "|");
};
auto res = process();
if (it != end) {
throw std::runtime_error("Unmatched '(' in pattern");
}
return "^(" + res + ")";
}
-56
View File
@@ -1,56 +0,0 @@
#pragma once
#include <regex>
#include <string>
enum common_regex_match_type {
COMMON_REGEX_MATCH_TYPE_NONE,
COMMON_REGEX_MATCH_TYPE_PARTIAL,
COMMON_REGEX_MATCH_TYPE_FULL,
};
struct common_string_range {
size_t begin;
size_t end;
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
if (begin > end) {
throw std::runtime_error("Invalid range");
}
}
// prevent default ctor
common_string_range() = delete;
bool empty() const {
return begin == end;
}
bool operator==(const common_string_range & other) const {
return begin == other.begin && end == other.end;
}
};
struct common_regex_match {
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
std::vector<common_string_range> groups;
bool operator==(const common_regex_match & other) const {
return type == other.type && groups == other.groups;
}
bool operator!=(const common_regex_match & other) const {
return !(*this == other);
}
};
class common_regex {
std::string pattern;
std::regex rx;
std::regex rx_reversed_partial;
public:
explicit common_regex(const std::string & pattern);
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
const std::string & str() const { return pattern; }
};
// For testing only (pretty print of failures).
std::string regex_to_reversed_partial_regex(const std::string & pattern);
+1
View File
@@ -51,6 +51,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
"DFlashDraftModel": "qwen",
"DeepseekV4ForCausalLM": "deepseek",
"DistilBertForMaskedLM": "bert",
"DistilBertForSequenceClassification": "bert",
"DistilBertModel": "bert",
+14 -1
View File
@@ -1273,7 +1273,7 @@ class TextModel(ModelBase):
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
if (n_experts := self.find_hparam(["num_local_experts", "num_experts", "n_routed_experts"], optional=True)) is not None:
self.gguf_writer.add_expert_count(n_experts)
logger.info(f"gguf: expert count = {n_experts}")
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
@@ -1291,6 +1291,8 @@ class TextModel(ModelBase):
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
elif score_func == "sqrtsoftplus":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SQRTSOFTPLUS)
else:
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
logger.info(f"gguf: expert score gating function = {score_func}")
@@ -2600,6 +2602,17 @@ class LazyTorchTensor(gguf.LazyBase):
return cls._wrap_fn(func)(*args, **kwargs)
if hasattr(torch, "float8_e8m0fnu"):
_torch_float8_e8m0 = torch.float8_e8m0fnu
LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
else:
# Older torch builds do not expose F8_E8M0. Keep the raw bytes so callers
# that know the format can decode them explicitly.
LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
# maybe we should fallback to text model's arch in that case, since not many models have both
+308 -1
View File
@@ -1,15 +1,18 @@
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Any, Callable, Iterable, TYPE_CHECKING
import numpy as np
import torch
if TYPE_CHECKING:
from torch import Tensor
from .base import MmprojModel, ModelBase, TextModel, gguf, logger
from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger
from .qwen import QwenModel
@@ -467,3 +470,307 @@ class DeepseekV32Model(DeepseekV2Model):
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
@ModelBase.register("DeepseekV4ForCausalLM")
class DeepseekV4Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK4
_skipped_mtp_tensors = 0
def __init__(self, *args, **kwargs):
type(self)._skipped_mtp_tensors = 0
super().__init__(*args, **kwargs)
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
raw_hparams = json.load(f)
for key, value in raw_hparams.items():
self.hparams.setdefault(key, value)
self.block_count = self.hparams["num_hidden_layers"]
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self._dsv4_fp8_dequantized: set[str] = set()
self._dsv4_bf16_tensors: set[str] = set()
self._dsv4_f32_tensors: set[str] = set()
self._dsv4_mxfp4_generated = False
self._collect_source_dtypes()
if type(self)._skipped_mtp_tensors:
logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)
# add a default chat template; if the model has a built-in template, it will be overridden later
template_path = Path(__file__).parent.parent / "models" / "templates" / "deepseek-ai-DeepSeek-V4.jinja"
if template_path.is_file():
with open(template_path, "r", encoding="utf-8") as f:
self.gguf_writer.add_chat_template(f.read())
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, _ = item
if name.startswith("mtp."):
cls._skipped_mtp_tensors += 1
return None
return super().filter_tensors(item)
@staticmethod
def _float8_dtypes() -> tuple[torch.dtype, ...]:
return tuple(
dtype for dtype in (
getattr(torch, "float8_e4m3fn", None),
getattr(torch, "float8_e5m2", None),
) if dtype is not None
)
@staticmethod
def _e8m0_to_float(scale: Tensor) -> Tensor:
torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
return scale.float()
bits = scale.view(torch.uint8).float()
return torch.exp2(bits - 127.0)
def _collect_source_dtypes(self) -> None:
for name, gen in self.model_tensors.items():
dtype = gen().dtype
if dtype == torch.bfloat16:
self._dsv4_bf16_tensors.add(name)
elif dtype == torch.float32:
self._dsv4_f32_tensors.add(name)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
self.gguf_writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
self.gguf_writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)
self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
self.gguf_writer.add_attention_output_group_count(hparams["o_groups"])
self.gguf_writer.add_attention_output_lora_rank(hparams["o_lora_rank"])
self.gguf_writer.add_attention_compress_ratios(hparams["compress_ratios"])
self.gguf_writer.add_attention_compress_rope_freq_base(hparams["compress_rope_theta"])
self.gguf_writer.add_hyper_connection_count(hparams["hc_mult"])
self.gguf_writer.add_hyper_connection_sinkhorn_iterations(hparams["hc_sinkhorn_iters"])
self.gguf_writer.add_hyper_connection_epsilon(hparams["hc_eps"])
self.gguf_writer.add_hash_layer_count(hparams["num_hash_layers"])
def dequant_model(self):
fp8_dtypes = self._float8_dtypes()
tensors_to_remove: list[str] = []
def dequant_fp8_weight(weight: Tensor, scale: Tensor) -> Tensor:
out_features, in_features = weight.shape
scale_f = self._e8m0_to_float(scale)
scale_f = scale_f.repeat_interleave(128, 0)[:out_features]
scale_f = scale_f.repeat_interleave(128, 1)[:, :in_features]
return weight.float() * scale_f
for name in list(self.model_tensors.keys()):
if not name.endswith(".scale"):
continue
weight_name = name.removesuffix(".scale") + ".weight"
if weight_name not in self.model_tensors:
continue
weight = self.model_tensors[weight_name]
scale = self.model_tensors[name]
if weight().dtype not in fp8_dtypes:
continue
self.model_tensors[weight_name] = lambda w=weight, s=scale: dequant_fp8_weight(w(), s())
self._dsv4_fp8_dequantized.add(weight_name)
tensors_to_remove.append(name)
for name in tensors_to_remove:
del self.model_tensors[name]
@staticmethod
def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
packed = weight.contiguous().view(torch.uint8)
scale_u8 = scale.contiguous().view(torch.uint8)
out_features, packed_cols = packed.shape
logical_cols = packed_cols * 2
if logical_cols % 32 != 0:
raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")
n_blocks = logical_cols // 32
if tuple(scale_u8.shape) != (out_features, n_blocks):
raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")
src = packed.reshape(out_features, n_blocks, 16)
low = src & 0x0F
high = (src >> 4) & 0x0F
# The safetensors bytes store adjacent values as low/high nibbles.
# ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
return raw.reshape(out_features, n_blocks * 17).cpu().numpy()
def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
n_experts = self.hparams["n_routed_experts"]
data: np.ndarray | None = None
consumed: list[str] = []
for eid in range(n_experts):
weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
raise KeyError(f"Missing routed expert tensors for {weight_name}")
weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
packed = self._pack_mxfp4_blocks(weight, scale)
if data is None:
data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
data[eid] = packed
consumed.extend((weight_name, scale_name))
assert data is not None
new_name = self.format_tensor_name(tensor_key, bid)
shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
return consumed
def _write_hash_routing_tensors(self) -> list[str]:
consumed: list[str] = []
for bid in range(self.hparams["num_hash_layers"]):
name = f"layers.{bid}.ffn.gate.tid2eid"
if name not in self.model_tensors:
raise KeyError(f"Missing hash routing tensor {name}")
data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
data = data_torch.to(torch.int32).cpu().numpy()
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, ".weight")
logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
self.gguf_writer.add_tensor(new_name, data)
consumed.append(name)
return consumed
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if self._dsv4_mxfp4_generated:
return ()
consumed: list[str] = self._write_hash_routing_tensors()
for bid in range(self.block_count):
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w1", gguf.MODEL_TENSOR.FFN_GATE_EXP))
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP))
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w3", gguf.MODEL_TENSOR.FFN_UP_EXP))
for name in consumed:
del self.model_tensors[name]
self._dsv4_mxfp4_generated = True
return ()
def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
return self.format_tensor_name(key, bid, suffix)
def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
root_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
"embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
"norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
"head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
"hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ".weight"),
"hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ".weight"),
"hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ".weight"),
}
if name in root_map:
return root_map[name]
match = re.match(r"layers\.(\d+)\.(.+)$", name)
if match is None:
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
layer = int(match.group(1))
if bid != layer:
raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")
layer_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
"hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ".weight"),
"hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ".weight"),
"hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ".weight"),
"hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ".weight"),
"hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ".weight"),
"hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ".weight"),
"attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ".weight"),
"attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
"attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
"attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
"attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
"attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
"attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
"attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
"attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ".weight"),
"attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
"attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
"attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
"attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
"attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
"attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ".weight"),
"attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
"attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
"attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
"attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
"ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
"ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
"ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
"ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ".weight"),
"ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
"ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
"ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
}
tensor_name = match.group(2)
if tensor_name in layer_map:
return layer_map[tensor_name]
if re.match(r"ffn\.experts\.\d+\.w[123]\.(weight|scale)$", tensor_name):
return gguf.MODEL_TENSOR.FFN_GATE_EXP, ".weight"
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if re.match(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$", name):
return []
tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
return []
return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del new_name, bid # unused
if name in self._dsv4_fp8_dequantized and n_dims >= 2:
return gguf.GGMLQuantizationType.Q8_0
if name in self._dsv4_f32_tensors:
return gguf.GGMLQuantizationType.F32
if name in self._dsv4_bf16_tensors and n_dims >= 2:
return gguf.GGMLQuantizationType.BF16
return False
def prepare_tensors(self):
super().prepare_tensors()
self._is_mxfp4 = True
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
+3 -7
View File
@@ -1551,8 +1551,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
ggml_backend_synchronize(split_backend);
// copy the input tensors to the split backend
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
@@ -1563,15 +1561,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
} else if (!split_backend->iface.cpy_tensor_async) {
} else {
ggml_backend_synchronize(split_backend);
}
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
ggml_backend_tensor_copy(input, input_cpy);
} else {
// wait for the split backend to finish using the input before overwriting it
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
} else if (!split_backend->iface.cpy_tensor_async) {
} else {
ggml_backend_synchronize(split_backend);
}
@@ -1676,8 +1674,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
}
ggml_backend_synchronize(split_backend);
if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
if (ec != GGML_STATUS_SUCCESS) {
+3 -2
View File
@@ -1111,11 +1111,12 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// e2m1 values (doubled)
// e2m1 values (doubled), shared by MXFP4 and NVFP4
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
GGML_TABLE_BEGIN(int8_t, kvalues_fp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
#define kvalues_mxfp4 kvalues_fp4
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
-1
View File
@@ -82,7 +82,6 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
+142 -4
View File
@@ -934,7 +934,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if defined __AVX2__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
@@ -963,7 +963,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
@@ -993,14 +993,152 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
sumi1 += y[ib].qs[j + 0] * kvalues_fp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_fp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_NVFP4 == 0);
const block_nvfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_NVFP4;
int ib = 0;
float sumf = 0;
#if defined(__AVX2__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m256i q8_01 = _mm256_loadu_si256((const __m256i *)y[2*ib + 0].qs);
const __m256i q8_23 = _mm256_loadu_si256((const __m256i *)y[2*ib + 1].qs);
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
//reordering
const __m256i q4_01 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_01_lo,q4_01_hi), _mm_unpacklo_epi64(q4_01_lo,q4_01_hi));
const __m256i q4_23 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_23_lo,q4_23_hi),_mm_unpacklo_epi64(q4_23_lo,q4_23_hi));
const __m256i p01 = mul_add_epi8(q4_01,q8_01);
const __m256i p_1 = _mm256_madd_epi16(p01, mone);
const __m256i p23 = mul_add_epi8(q4_23,q8_23);
const __m256i p_2 = _mm256_madd_epi16(p23, mone);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_fmadd_ps(scales01, _mm256_cvtepi32_ps(p_1), accum);
accum = _mm256_fmadd_ps(scales23, _mm256_cvtepi32_ps(p_2), accum);
}
sumf = hsum_float_8(accum);
#elif defined(__AVX__)
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
for(; ib < nb; ib++){
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
const __m128i q8_0 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 0));
const __m128i q8_1 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 16));
const __m128i q8_2 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 0));
const __m128i q8_3 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 16));
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
const __m128i q4_0 = _mm_unpacklo_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_1 = _mm_unpackhi_epi64(q4_01_lo, q4_01_hi);
const __m128i q4_2 = _mm_unpacklo_epi64(q4_23_lo, q4_23_hi);
const __m128i q4_3 = _mm_unpackhi_epi64(q4_23_lo, q4_23_hi);
const __m128i p0_i32 = mul_sum_i8_pairs(q4_0, q8_0);
const __m128i p1_i32 = mul_sum_i8_pairs(q4_1, q8_1);
const __m128i p2_i32 = mul_sum_i8_pairs(q4_2, q8_2);
const __m128i p3_i32 = mul_sum_i8_pairs(q4_3, q8_3);
const __m128 p0 = _mm_cvtepi32_ps(p0_i32);
const __m128 p1 = _mm_cvtepi32_ps(p1_i32);
const __m128 p2 = _mm_cvtepi32_ps(p2_i32);
const __m128 p3 = _mm_cvtepi32_ps(p3_i32);
const __m256 p01 = _mm256_set_m128(p1, p0);
const __m256 p23 = _mm256_set_m128(p3, p2);
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p01, scales01));
accum = _mm256_add_ps(accum, _mm256_mul_ps(p23, scales23));
}
sumf = hsum_float_8(accum);
#endif
for (;ib < nb; ++ib) {
for (int s_idx = 0; s_idx < 4; ++s_idx) {
const float d = GGML_CPU_UE4M3_TO_FP32(x[ib].d[s_idx]);
const int q8_block = s_idx / 2;
const int q8_off = (s_idx % 2) * QK_NVFP4_SUB;
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
int sumi_lo = 0, sumi_hi = 0;
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_fp4[qv & 0xf];
sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_fp4[qv >> 4];
}
sumf += dy * d * (sumi_lo + sumi_hi);
}
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+8
View File
@@ -82,6 +82,9 @@ float ggml_table_f32_f16[1 << 16];
// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB) (simd-mappings.h)
float ggml_table_f32_ue4m3[1 << 8];
#if defined(__ARM_ARCH)
struct ggml_arm_arch_features_type {
int sve_cnt;
@@ -3798,6 +3801,11 @@ void ggml_cpu_init(void) {
ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
}
// initialize UE4M3 table (256 entries)
for (int i = 0; i < (1 << 8); ++i) {
ggml_table_f32_ue4m3[i] = ggml_ue4m3_to_fp32(i);
}
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+11
View File
@@ -120,6 +120,10 @@ extern float ggml_table_f32_f16[1 << 16];
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_e8m0_half[1 << 8];
// precomputed f32 table for ue4m3 (1 KB)
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_ue4m3[1 << 8];
// Use lookup table for E8M0 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
@@ -127,6 +131,13 @@ extern float ggml_table_f32_e8m0_half[1 << 8];
#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
#endif
// Use lookup table for UE4M3 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_table_f32_ue4m3[(uint8_t)(x)]
#else
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_ue4m3_to_fp32(x)
#endif
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
+9 -5
View File
@@ -664,7 +664,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
template <int ncols1>
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
static __global__ void flash_attn_mask_to_KV_max(
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
const half2 * mask_ptr, int * KV_max_ptr, const int ne30, const int64_t s31, const int64_t s33) {
const half2 * GGML_CUDA_RESTRICT mask = mask_ptr;
int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
const int ne31 = gridDim.x;
const int tid = threadIdx.x;
const int sequence = blockIdx.y;
@@ -1089,8 +1092,8 @@ void launch_fattn(
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
const int64_t s31 = mask->nb[1] / sizeof(half2);
const int64_t s33 = mask->nb[3] / sizeof(half2);
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
@@ -1099,8 +1102,9 @@ void launch_fattn(
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
KV_max.alloc(ne_KV_max);
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_KV_max, block_dim_KV_max, 0, main_stream);
ggml_cuda_kernel_launch(flash_attn_mask_to_KV_max<ncols1>, launch_params,
(const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
CUDA_CHECK(cudaGetLastError());
}
+4
View File
@@ -2003,6 +2003,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
+9 -5
View File
@@ -76,6 +76,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -144,6 +145,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
@@ -219,6 +221,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
@@ -296,6 +299,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
@@ -1308,12 +1312,12 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
return;
}
if constexpr (DV <= 256) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
if constexpr (DV <= 256) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
+5 -5
View File
@@ -99,12 +99,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
return;
}
if constexpr (DKQ <= 256) {
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
if constexpr (DKQ <= 256) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
} else {
GGML_ABORT("fatal error");
+17 -14
View File
@@ -78,26 +78,29 @@ static __global__ void k_get_rows_float(
template<typename grad_t, typename dst_t>
static __global__ void k_get_rows_back_float(
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst,
const int64_t ncols, const int64_t nrows_grad, const int64_t nrows_dst) {
const int col = blockIdx.x*blockDim.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
float sum = 0.0f;
ggml_cuda_pdl_sync();
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
// grid.y is clamped to the CUDA grid limit, so stride over the destination rows
for (int64_t dst_row = blockIdx.y; dst_row < nrows_dst; dst_row += gridDim.y) {
float sum = 0.0f;
for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}
dst[dst_row*ncols + col] = sum;
}
}
template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
@@ -302,7 +305,7 @@ void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne1, 1);
const dim3 block_nums(block_num_x, MIN(ne1, (int64_t)UINT16_MAX), 1);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10, ne1);
}
+4 -20
View File
@@ -3192,24 +3192,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
// Enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA
// Excluding this path for HIP and MUSA as a precaution.
// According to the summary in https://github.com/ggml-org/llama.cpp/pull/20793#issuecomment-4275794315, this change is not beneficial for hip anyways.
// Additionally, there is a lot of anectodal evidence that hip/musa stream behavior might not always 1:1 match CUDA behavior.
// e.g. https://github.com/ROCm/rocm-systems/issues/5109
// It thus makes sense to exclude this path for HIP and MUSA. This PR was not aimed these backends, the majority of testing happened on CUDA.
// This can be revisited in the future if enabling copy_from_host benefits hip/MUSA, and if the PR author can extensively test on these backends.
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
const bool copy_from_host = false;
#else
const bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU;
#endif
if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) {
if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
return false;
}
if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(buf_dst)) {
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
return false;
}
@@ -3220,17 +3207,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) ||
!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) {
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
#endif // NDEBUG
return false;
}
if (copy_from_host) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
} else if (backend_src != backend_dst) {
if (backend_src != backend_dst) {
// copy on src stream
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+7
View File
@@ -368,5 +368,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
return true;
}
// gfx900 (Vega 10) lacks native dp4a, loses to dequant + hipBLAS
// for dense matrices; keep MMQ only for MoE, where the
// hipBLAS path is much slower.
if (cc == GGML_CUDA_CC_VEGA) {
return n_experts > 0;
}
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
@@ -92,7 +92,7 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
continue
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
if head_size_kq == 512 and ncols2 not in (2, 4, 8): # Gemma 4 (+ MTP)
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
continue
-1
View File
@@ -23,7 +23,6 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
add_library(htp_iface OBJECT
+221 -26
View File
@@ -43,6 +43,7 @@
#include "htp-opnode.h"
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
#include "htp_iface.h"
#include "htp-drv.h"
@@ -62,6 +63,7 @@ static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
static int opt_hostbuf = 1; // hostbuf ON by default
static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU
static int opt_fa_select = 2; // 2 = HMX -> HVX -> CPU, 1 = HVX -> CPU, 0 = CPU (unsupported)
// Default PMU events, if profiling with PMU (mode=2) is enabled
// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html
@@ -125,6 +127,11 @@ static const char * htp_event_name(uint16_t id) {
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
case HTP_TRACE_EVT_HVX_FA_QK: return "HVX_QK_FA";
case HTP_TRACE_EVT_HVX_FA_SFM: return "HVX_SFM_FA";
case HTP_TRACE_EVT_HVX_FA_Q_PREP: return "HVX_Q_PREP";
case HTP_TRACE_EVT_HVX_FA_K_PREP: return "HVX_K_PREP";
case HTP_TRACE_EVT_HVX_FA_V_PREP: return "HVX_V_PREP";
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
default: return "UNKNOWN";
}
@@ -1879,6 +1886,162 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
// ** backend interface
static bool ggml_hexagon_flash_attn_is_hmx_eligible(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * sinks
) {
if (sess->n_hmx == 0) {
return false;
}
if (opt_fa_select < 2) {
return false;
}
if (k->type != GGML_TYPE_F16 || v->type != GGML_TYPE_F16) {
return false;
}
const uint32_t DK = q->ne[0];
const uint32_t DV = v->ne[0];
if (DK % 64 != 0 || DV % 64 != 0) {
return false;
}
// Fall back to HVX for small token counts if head dimension is small (DK <= 128)
const uint32_t neq1 = q->ne[1];
if (DK <= 128 && neq1 < 5) {
return false;
}
return true;
}
static bool ggml_hexagon_precompute_flash_attn_params(
const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op,
struct htp_fa_kernel_params * kparams
) {
if (opt_fa_select < 1) {
return false;
}
memset(kparams, 0, sizeof(*kparams));
const struct ggml_tensor * q = op->src[0];
const struct ggml_tensor * k = op->src[1];
const struct ggml_tensor * v = op->src[2];
const struct ggml_tensor * mask = op->src[3];
const struct ggml_tensor * dst = op;
const uint32_t neq0 = q->ne[0]; // head_dim (DK)
const uint32_t neq1 = q->ne[1]; // n_tokens
const uint32_t neq2 = q->ne[2]; // n_heads
const uint32_t nek1 = k->ne[1]; // kv_len
const uint32_t nev0 = v->ne[0]; // head_dim (DV)
const uint32_t DK = neq0;
const uint32_t DV = nev0;
const uint32_t n_kv_heads = k->ne[2];
const uint32_t G = neq2 / n_kv_heads;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, &op->op_params[0], sizeof(float));
memcpy(&max_bias, &op->op_params[1], sizeof(float));
memcpy(&logit_softcap, &op->op_params[2], sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
kparams->scale = scale;
kparams->max_bias = max_bias;
kparams->logit_softcap = logit_softcap;
kparams->is_q_fp32 = (q->type == GGML_TYPE_F32) ? 1 : 0;
kparams->is_dst_fp32 = (dst->type == GGML_TYPE_F32) ? 1 : 0;
kparams->G = G;
const uint32_t n_head = q->ne[2];
kparams->n_head_log2 = 1u << (uint32_t) std::floor(std::log2(n_head));
kparams->m0 = std::pow(2.0f, -(max_bias) / kparams->n_head_log2);
kparams->m1 = std::pow(2.0f, -(max_bias / 2.0f) / kparams->n_head_log2);
// Check HMX eligibility
const struct ggml_tensor * sinks = op->src[4];
if (ggml_hexagon_flash_attn_is_hmx_eligible(sess, q, k, v, sinks)) {
size_t Br = 0, Bc = 0;
int ret = hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, sess->vtcm_size, sess->n_threads);
if (ret == 0) {
kparams->kernel_type = HTP_FA_KERNEL_HMX;
kparams->Br = Br;
kparams->Bc = Bc;
kparams->n_kv_blocks = (nek1 + Bc - 1) / Bc;
kparams->n_threads = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? sess->n_threads : 1;
kparams->u.hmx.g_br = hex_align_up(G * Br, 32);
kparams->u.hmx.pipeline = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? 1 : 0;
kparams->vtcm_size = hmx_fa_compute_vtcm_usage(G, DK, DV, Br, Bc, kparams->n_threads, kparams->u.hmx.pipeline != 0);
const size_t row_vec_bytes = hex_align_up(Bc * sizeof(uint16_t), 256);
kparams->u.hmx.row_buf_stride = row_vec_bytes / 128; // HVX vector is 128 bytes
const size_t m_line_bytes = hex_align_up(Bc * sizeof(uint16_t), 128);
kparams->u.hmx.mask_buf_row_stride = m_line_bytes / sizeof(uint16_t);
kparams->u.hmx.mask_broadcast = (mask != nullptr && mask->ne[2] == 1) ? 1 : 0;
kparams->u.hmx.div_G = init_fastdiv_values(G);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = 0;
kparams->qrows_per_thread = 0;
return true;
}
}
// Fallback to HVX
kparams->kernel_type = HTP_FA_KERNEL_HVX;
kparams->Br = 1;
kparams->Bc = 64; // FLASH_ATTN_BLOCK_SIZE
kparams->n_kv_blocks = (k->ne[1] + 64 - 1) / 64;
kparams->n_threads = sess->n_threads;
const size_t size_q_row_padded = hex_round_up(q->ne[0] * (kparams->is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(k->ne[0] * 2, 128);
const size_t size_v_row_padded = hex_round_up(v->ne[0] * 2, 128);
kparams->vtcm_size = hvx_fa_compute_vtcm_usage(DK, DV, kparams->is_q_fp32 != 0, mask != nullptr, sess->n_threads);
kparams->u.hvx.size_q_row_padded = size_q_row_padded;
kparams->u.hvx.size_k_row_padded = size_k_row_padded;
kparams->u.hvx.size_v_row_padded = size_v_row_padded;
kparams->u.hvx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
kparams->u.hvx.src0_div1 = init_fastdiv_values(q->ne[1]);
kparams->u.hvx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
kparams->u.hvx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
kparams->u.hvx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
kparams->u.hvx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
kparams->qrows = q->ne[1] * q->ne[2] * q->ne[3];
kparams->qrows_per_thread = (kparams->qrows + sess->n_threads - 1) / sess->n_threads;
return true;
}
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
@@ -1912,6 +2075,17 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return false;
}
struct htp_fa_kernel_params kparams;
if (!ggml_hexagon_precompute_flash_attn_params(sess, op, &kparams)) {
return false;
}
if ((size_t) kparams.vtcm_size > sess->vtcm_size) {
HEX_VERBOSE("ggml-hex: skip flash_attn_ext because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
return false;
}
return true;
}
@@ -2211,14 +2385,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
size_t vtcm_src0_size = 0, vtcm_src1_size = 0;
size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0;
uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16;
uint32_t best_n_prefetch = 2;
size_t total_size = 0;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
if (total_size <= vtcm_budget) {
best_n_prefetch = d;
@@ -2228,14 +2402,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
if (best_n_prefetch == 2 && total_size > vtcm_budget) {
total_size = htp_mm_hvx_id_get_vtcm_sizes(
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2,
&vtcm_src0_size, &vtcm_src1_size
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
);
}
kparams->n_prefetch = best_n_prefetch;
kparams->vtcm_size = total_size;
kparams->vtcm_src0_size = vtcm_src0_size;
kparams->vtcm_src1_size = vtcm_src1_size;
kparams->vtcm_dst_size = 0;
kparams->vtcm_dst_size = vtcm_dst_size;
} else {
bool try_tiled = (k_align && opt_mm_select >= 2);
if (try_tiled) {
@@ -2441,11 +2615,12 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src3_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2453,13 +2628,10 @@ static void ggml_hexagon_precompute_fused_qkv_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2471,9 +2643,6 @@ static void ggml_hexagon_precompute_fused_qkv_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
@@ -2492,7 +2661,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t src3_sz = src3_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
@@ -2500,6 +2669,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2510,7 +2680,8 @@ static void ggml_hexagon_precompute_fused_qkv_params(
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_src3_size = src3_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -2536,11 +2707,12 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src2_sz_per_thread = 0;
uint32_t best_n_prefetch = 16;
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
size_t src1_sz = src1_sz_per_thread;
@@ -2548,12 +2720,9 @@ static void ggml_hexagon_precompute_fused_ffn_params(
best_n_prefetch = 2;
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
if (tiled_vtcm_size <= sess->vtcm_size) {
best_n_prefetch = d;
@@ -2564,9 +2733,6 @@ static void ggml_hexagon_precompute_fused_ffn_params(
}
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
}
@@ -2582,13 +2748,14 @@ static void ggml_hexagon_precompute_fused_ffn_params(
size_t src1_sz = src1_sz_per_thread;
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
bool try_tiled = (opt_mm_select >= 2);
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = tiled_vtcm_size;
kparams->n_prefetch = best_n_prefetch;
} else {
@@ -2598,7 +2765,8 @@ static void ggml_hexagon_precompute_fused_ffn_params(
kparams->vtcm_src0_size = src0_sz;
kparams->vtcm_src1_size = flat_src1_sz;
kparams->vtcm_src2_size = src2_sz;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz;
kparams->vtcm_dst_size = quant_scratch_size;
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + quant_scratch_size;
kparams->n_prefetch = best_n_prefetch;
}
}
@@ -3243,7 +3411,7 @@ static inline bool op_is_compute(ggml_tensor *node)
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
}
static bool is_hmx_eligible(const ggml_tensor * t) {
static bool mm_is_hmx_eligible(const ggml_tensor * t) {
if (opt_nhmx == 0) { return false; }
const ggml_tensor * src0 = t->src[0];
@@ -3262,7 +3430,7 @@ static bool is_hmx_eligible(const ggml_tensor * t) {
static bool is_mergeable_mul_mat(const ggml_tensor * t) {
if (!t || t->op != GGML_OP_MUL_MAT) return false;
if (t->src[1]->type != GGML_TYPE_F32) return false;
return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t);
return ggml_is_quantized(t->src[0]->type) && !mm_is_hmx_eligible(t);
}
static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) {
@@ -3357,6 +3525,26 @@ static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph *
}
}
if (n->op == GGML_OP_MUL_MAT && next_node) {
if (next_node->op == GGML_OP_ADD && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
if (next_node->src[0] == n || next_node->src[1] == n) {
struct htp_mm_kernel_params kparams;
ggml_hexagon_precompute_matmul_params(sess, n->src[0], n->src[1], next_node, &kparams);
if ((size_t)kparams.vtcm_size <= sess->vtcm_size) {
htp_opnode node(n, {}, HTP_OP_MUL_MAT_ADD);
node.add_fused(next_node);
memcpy(node.kernel_params, &kparams, sizeof(kparams));
nodes.push_back(std::move(node));
i += 1;
return true;
} else {
HEX_VERBOSE("ggml-hex: skip MUL_MAT_ADD fusion because VTCM needed (%d) > budget (%zu)\n",
kparams.vtcm_size, sess->vtcm_size);
}
}
}
}
return false;
}
@@ -3393,6 +3581,11 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
node.node->src[0], node.node->src[1], node.node,
(struct htp_mm_kernel_params *)node.kernel_params
);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
ggml_hexagon_precompute_flash_attn_params(sess,
node.node,
(struct htp_fa_kernel_params *)node.kernel_params
);
}
computed_nodes.push_back(std::move(node));
}
@@ -4079,6 +4272,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
const char * str_nhmx = getenv("GGML_HEXAGON_NHMX");
const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT");
const char * str_fa_select = getenv("GGML_HEXAGON_FA_SELECT");
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
const char * str_vmem = getenv("GGML_HEXAGON_VMEM");
@@ -4120,6 +4314,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx);
opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select;
opt_fa_select = str_fa_select ? atoi(str_fa_select) : opt_fa_select;
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
+13 -1
View File
@@ -11,6 +11,7 @@
#include <stdio.h>
#include "htp-ops.h"
#include "htp/matmul-ops.h"
#include "htp/flash-attn-ops.h"
struct htp_opnode {
ggml_tensor * node = nullptr;
@@ -335,7 +336,8 @@ struct htp_opformat {
}
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN ||
node.opcode == HTP_OP_MUL_MAT_ADD) {
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
@@ -350,6 +352,16 @@ struct htp_opformat {
path = "hvx-flat";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
const auto * kparams = (const struct htp_fa_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
if (type == HTP_FA_KERNEL_HMX) {
path = kparams->u.hmx.pipeline ? "hmx-pipe" : "hmx-seq";
} else if (type == HTP_FA_KERNEL_HVX) {
path = "hvx";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else {
snprintf(str, max_size, "----");
}
+2 -7
View File
@@ -20,9 +20,6 @@ add_library(${HTP_LIB} SHARED
worker-pool.c
hex-dma.c
hmx-queue.c
flash-attn-ops.c
hmx-flash-attn-ops.c
matmul-ops.c
binary-ops.c
unary-ops.c
sum-rows-ops.c
@@ -42,16 +39,14 @@ add_library(${HTP_LIB} SHARED
solve-tri-ops.c
gated-delta-net-ops.c
pad-ops.c
matmul-ops.c
flash-attn-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
File diff suppressed because it is too large Load Diff
+253
View File
@@ -0,0 +1,253 @@
#ifndef HTP_FLASH_ATTN_OPS_H
#define HTP_FLASH_ATTN_OPS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hex-fastdiv.h"
#include "hex-common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Tile constants (mirrored from hmx-utils.h for use on host side if needed)
#define HMX_FP16_TILE_N_ROWS 32
#define HMX_FP16_TILE_N_COLS 32
#define HMX_FP16_TILE_N_ELMS 1024
#define HMX_FP16_TILE_SIZE 2048
#define HVX_FA_DMA_CACHE_SIZE 128
#define HMX_FA_DMA_CACHE_SIZE 4
#define HTP_FA_M_INITIAL_VAL -10000.0f
enum htp_fa_kernel_type {
HTP_FA_KERNEL_UNSUPPORTED = 0,
HTP_FA_KERNEL_HVX,
HTP_FA_KERNEL_HMX
};
struct htp_fa_kernel_params {
uint8_t kernel_type; // enum htp_fa_kernel_type
uint8_t is_q_fp32; // 1 = Q type is F32, 0 = F16
uint8_t is_dst_fp32; // 1 = dst type is F32, 0 = F16
uint8_t n_threads; // Number of threads to run
// Common parameters
uint16_t Br;
uint16_t Bc;
uint16_t n_kv_blocks; // also HVX's n_blocks
uint16_t G; // GQA factor (n_heads / n_kv_heads)
float scale;
float max_bias;
float logit_softcap;
uint32_t vtcm_size;
uint32_t qrows;
uint32_t qrows_per_thread;
float m0;
float m1;
uint32_t n_head_log2;
struct fastdiv_values src3_div2;
struct fastdiv_values src3_div3;
union {
struct {
uint32_t g_br;
uint32_t row_buf_stride;
uint32_t mask_buf_row_stride;
int32_t mask_broadcast;
int32_t pipeline;
struct fastdiv_values div_G;
} hmx;
struct {
uint32_t size_q_row_padded;
uint32_t size_k_row_padded;
uint32_t size_v_row_padded;
struct fastdiv_values src0_div21;
struct fastdiv_values src0_div1;
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
} hvx;
} u;
};
#if defined(__cplusplus)
static_assert(sizeof(struct htp_fa_kernel_params) <= 128, "htp_fa_kernel_params is too large for kernel_params blob");
#endif
// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration.
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static inline size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf
const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc]
const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br]
const size_t col_vec_size = hex_align_up(g_br * sizeof(float), 256); // m, l, etc.
const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256);
const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128);
const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096) * HMX_FA_DMA_CACHE_SIZE;
const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128);
return q_tile_size * 1 // Q tiles
+ o_tile_size * 2 // O ping-pong
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
+ row_vec_size * 2 * n_threads // per-thread softmax row scratch
+ m_buf_size * 1 // mask VTCM buffer [Br rows]
+ slopes_size // Slopes
+ 256 * 2; // HMX scales (id + qk)
}
#define FA_HVX_BLOCK_SIZE 64
static inline size_t hvx_fa_compute_vtcm_usage(size_t DK, size_t DV, bool is_q_fp32, bool has_mask, size_t n_threads) {
const size_t size_q_row_padded = hex_round_up(DK * (is_q_fp32 ? 4 : 2), 128);
const size_t size_k_row_padded = hex_round_up(DK * sizeof(__fp16), 128);
const size_t size_v_row_padded = hex_round_up(DV * sizeof(__fp16), 128);
const size_t size_q_block = size_q_row_padded * 1;
const size_t size_k_block = size_k_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FA_HVX_BLOCK_SIZE;
const size_t size_m_block = hex_round_up(FA_HVX_BLOCK_SIZE * sizeof(__fp16), 128);
const size_t size_vkq_acc = hex_round_up(DV * sizeof(float), 128);
const size_t size_per_thread = size_q_block * 1
+ size_k_block * 2
+ size_v_block * 2
+ (has_mask ? size_m_block * HVX_FA_DMA_CACHE_SIZE : 0)
+ size_vkq_acc;
return size_per_thread * n_threads;
}
#define FA_MIN_KV_BLOCKS 3
// Cost-based (Br, Bc) search for flash attention with pipeline constraint.
static inline int hmx_fa_find_chunk_size(size_t * Br_out,
size_t * Bc_out,
size_t gqa_factor,
size_t DK,
size_t DV,
size_t qo_len,
size_t kv_len,
size_t vtcm_budget,
size_t n_threads) {
const size_t T = HMX_FP16_TILE_N_ROWS; // 32
const size_t br_unit = hmx_ceil_div(T, gqa_factor);
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64
const size_t fp16 = sizeof(__fp16);
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
// Approximate per-unit VTCM costs (without per-buffer alignment padding).
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * sizeof(float); // Q + O*2 + 4 col vectors
const size_t per_gbr2 = fp16; // D diagonal matrix
const size_t per_bc =
3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs
const size_t per_gbr_bc = 2 * fp16; // S + P
const size_t overhead = 256 * 2 + 13 * 4096;
if (vtcm_budget <= overhead) {
return -1;
}
const size_t usable = vtcm_budget - overhead;
// Br_max: largest Br aligned to br_unit that does not exceed qo_len.
const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit;
// Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS.
// Only relax when kv_len is too short to form enough blocks.
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
// Cost coefficients calibrated from profiling
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
size_t best_cost = SIZE_MAX, best_mn = 0;
size_t best_Br = 0, best_Bc = 0;
for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) {
const size_t g_br = hex_align_up(gqa_factor * Br, T);
// g_br-dependent VTCM cost: g_br * per_gbr + g_br*g_br * per_gbr2
const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2;
if (gbr_cost >= usable) {
if (Br == br_unit) {
break;
}
continue;
}
// Analytically solve for max Bc:
// remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE)
// The Br * fp16 term accounts for the VTCM mask buffer [Br * Bc].
const size_t remain = usable - gbr_cost;
const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE;
size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit);
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
// Exact VTCM verification (alignment padding may push over budget)
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) {
Bc -= bc_unit;
}
if (Bc < bc_unit) {
if (Br == br_unit) {
break;
}
continue;
}
const size_t q_blocks = (qo_len + Br - 1) / Br;
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
const size_t mn = Br * Bc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_Br = Br;
best_Bc = Bc;
}
if (Br == br_unit) {
break;
}
}
if (best_Br == 0) {
return -1;
}
*Br_out = best_Br;
*Bc_out = best_Bc;
return 0;
}
#ifdef __cplusplus
}
#endif
#endif /* HTP_FLASH_ATTN_OPS_H */
+15 -15
View File
@@ -138,27 +138,28 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
}
dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;
q->dptr[q->push_idx] = dptr;
if (size) {
desc->next = NULL;
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 0;
desc->done = 0;
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
desc->done = 1;
desc->desc_size = 0;
desc->done = 1;
}
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
@@ -320,7 +321,7 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
#define DMA_CACHE_MAX_SIZE 64U
#define DMA_CACHE_MAX_SIZE 256U
typedef struct {
uint8_t *base;
@@ -352,20 +353,19 @@ static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * sr
if (c->src[i] == (uint32_t) src) {
c->age[i] = 0;
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
// FARF(ERROR, "dma-cache: found %p", src);
} else {
c->age[i]++;
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
}
}
if (!dst) {
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
c->age[o_idx] = 0;
c->src[o_idx] = (uint32_t) src;
dst = c->base + o_idx * c->line_size; // normal nrows dma
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
}
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
return dma_queue_push_single_1d(q, dma_make_ptr(dst, src), 0);
}
#ifdef __cplusplus
@@ -0,0 +1,96 @@
#ifndef HMX_FA_KERNELS_H
#define HMX_FA_KERNELS_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include "hvx-utils.h"
#include "hmx-utils.h"
// HMX-specific parameters, offsets and inner kernels for Flash Attention
// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6
// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile
static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = {
0 * 136, 0 * 136 + 6,
1 * 136, 1 * 136 + 6,
2 * 136, 2 * 136 + 6,
3 * 136, 3 * 136 + 6,
4 * 136, 4 * 136 + 6,
5 * 136, 5 * 136 + 6,
6 * 136, 6 * 136 + 6,
7 * 136, 7 * 136 + 6,
8 * 136, 8 * 136 + 6,
9 * 136, 9 * 136 + 6,
10 * 136, 10 * 136 + 6,
11 * 136, 11 * 136 + 6,
12 * 136, 12 * 136 + 6,
13 * 136, 13 * 136 + 6,
14 * 136, 14 * 136 + 6,
15 * 136, 15 * 136 + 6,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
0, 0,
};
// Inner HMX tile computation kernels
static inline void hmx_fa_qk_dot_tile(
const __fp16 * row_tiles,
const __fp16 * col_tiles,
__fp16 * out_tile,
size_t n_dot_tiles
) {
for (size_t k = 0; k < n_dot_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(out_tile, 0);
}
static inline void hmx_fa_o_update_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
const __fp16 * p_tile_in,
const __fp16 * v_tile_in,
__fp16 * o_tile_out,
size_t n_col_tiles
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
for (size_t k = 0; k < n_col_tiles; ++k) {
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
p_tile_in += HMX_FP16_TILE_N_ELMS;
v_tile_in += HMX_FP16_TILE_N_ELMS;
}
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
static inline void hmx_fa_o_norm_tile(
const __fp16 * d_diag,
const __fp16 * o_rc,
__fp16 * o_out
) {
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
Q6_mxmem_AR_after_hf(o_out, 0);
}
#endif /* HMX_FA_KERNELS_H */
File diff suppressed because it is too large Load Diff
@@ -712,7 +712,17 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
// output : fp16 -> f32p
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) {
static void transfer_output_chunk_fp16_to_fp32(
float *restrict dst,
const float *restrict src2,
const __fp16 *restrict vtcm_src,
uint32_t start_row,
uint32_t n_rows,
uint32_t n_cols,
uint32_t dst_stride,
uint32_t src2_stride,
uint32_t dst_cols
) {
assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0);
const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS;
@@ -727,6 +737,7 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile
const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1)
const float *src2_row_base = src2 ? (src2 + r * src2_stride) : NULL;
#pragma unroll(4)
for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) {
@@ -738,9 +749,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0);
HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride);
*pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
*pv_out0 = v_out0;
if (r + 1 < n_rows) {
*pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
*pv_out1 = v_out1;
}
}
@@ -752,9 +774,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
HVX_Vector v = ((const HVX_Vector *) tile)[r1];
HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)));
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
}
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), v_out0);
if (r + 1 < n_rows) {
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)));
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
if (src2_row_base) {
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
}
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), v_out1);
}
}
}
@@ -763,11 +796,13 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
typedef struct {
const __fp16 *vtcm_src;
float *dst;
const float *src2;
uint32_t n_tasks;
uint32_t n_tot_chunks;
uint32_t n_chunks_per_task;
uint32_t n_cols;
uint32_t dst_stride; // DDR row stride
uint32_t src2_stride; // DDR row stride for residual
uint32_t dst_cols; // Actual output columns
struct htp_thread_trace * traces;
} output_transfer_task_state_t;
+35 -35
View File
@@ -42,14 +42,14 @@ static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VL
// Full range: start_row=0, end_row=n_cols.
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const __fp16 * restrict vtcm_src,
int n_cols,
int k,
int src_stride,
int start_row,
int end_row) {
uint32_t n_cols,
uint32_t k,
size_t src_stride,
uint32_t start_row,
uint32_t end_row) {
assert(k % HMX_FP16_TILE_N_COLS == 0);
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const uint32_t n_k_tiles = k / HMX_FP16_TILE_N_COLS;
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
@@ -65,14 +65,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
if (pair_scatter) {
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = 2 * HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -86,7 +86,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
assert(c_byte_step % 128 == 0);
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
@@ -95,7 +95,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
@@ -105,14 +105,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
// Fallback: scatter one K-tile per call (region 2047, masked).
const int c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const int n_c_iters = k / c_step;
const uint32_t c_step = HMX_FP16_TILE_N_COLS;
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
const uint32_t n_c_iters = k / c_step;
for (int r = start_row; r < end_row; r += 2) {
const int ct = r / HMX_FP16_TILE_N_ROWS;
const int local_r = r % HMX_FP16_TILE_N_ROWS;
for (uint32_t r = start_row; r < end_row; r += 2) {
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
@@ -122,7 +122,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
if (p1) {
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
@@ -131,7 +131,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int i = 0; i < n_c_iters; ++i) {
for (uint32_t i = 0; i < n_c_iters; ++i) {
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
@@ -148,24 +148,24 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
// Full range: start_row=0, end_row=n_rows.
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const __fp16 * restrict src,
int n_rows,
int head_dim,
int src_stride,
int n_row_tiles,
int start_row,
int end_row) {
uint32_t n_rows,
uint32_t head_dim,
size_t src_stride,
uint32_t n_row_tiles,
uint32_t start_row,
uint32_t end_row) {
__builtin_assume(head_dim > 0);
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
for (int r = start_row; r < end_row; r += 2) {
for (uint32_t r = start_row; r < end_row; r += 2) {
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
// Row-pair invariants hoisted out of the c loop.
const int r0 = r / HMX_FP16_TILE_N_ROWS;
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
const uint32_t r0 = r / HMX_FP16_TILE_N_ROWS;
const uint32_t r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
@@ -174,7 +174,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
const size_t tb_step = 2 * tile_stride_elms;
if (pv_in1) {
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_Vector v1 = *pv_in1++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
@@ -185,7 +185,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
}
} else {
const HVX_Vector vzero = Q6_V_vzero();
for (int c = 0; c < head_dim; c += 64) {
for (uint32_t c = 0; c < head_dim; c += 64) {
HVX_Vector v0 = *pv_in0++;
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
+6
View File
@@ -60,6 +60,7 @@ enum htp_op_code {
HTP_OP_MUL_MAT_ID,
HTP_OP_MUL_MAT_QKV,
HTP_OP_MUL_MAT_FFN,
HTP_OP_MUL_MAT_ADD,
HTP_OP_RMS_NORM,
HTP_OP_RMS_NORM_MUL,
HTP_OP_UNARY_SILU,
@@ -175,6 +176,11 @@ enum htp_trace_event_id {
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
HTP_TRACE_EVT_HVX_W_PREP = 24,
HTP_TRACE_EVT_HVX_O_PROC = 25,
HTP_TRACE_EVT_HVX_FA_QK = 26,
HTP_TRACE_EVT_HVX_FA_SFM = 27,
HTP_TRACE_EVT_HVX_FA_Q_PREP = 28,
HTP_TRACE_EVT_HVX_FA_K_PREP = 29,
HTP_TRACE_EVT_HVX_FA_V_PREP = 30,
HTP_TRACE_EVT_HMX_COMP = 40,
};
+1 -12
View File
@@ -134,16 +134,7 @@ static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1)
}
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
#if __HVX_ARCH__ < 79
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
v = Q6_V_vmux_QVV(nan, neg_inf, v);
#endif
return v;
return Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
}
#if __HVX_ARCH__ >= 79
@@ -170,8 +161,6 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
// Ideally should just be Q6_Vh_equals_Vhf(vin)
+39
View File
@@ -16,6 +16,7 @@
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_LOG2E_F 1.44269504f
#define EXP_ONE (0x3f800000) // 1.0
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
@@ -213,4 +214,42 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
}
}
static inline HVX_Vector hvx_vec_exp2_f16(HVX_Vector x_v) {
const HVX_Vector zero_v = Q6_V_vzero();
const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5
// Clamp input to prevent integer underflow in FP16-to-INT16 conversion
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
x_v = Q6_Vhf_vmax_VhfVhf(v_clamp_min, x_v);
// k = round_toward_neg_inf(x); f = (float)k; frac = x - f
HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v));
HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16
HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16
HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16
// Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0
HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0
// Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k
y = Q6_Vhf_equals_Vqf16(y);
HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11);
y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp);
HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp);
y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10);
return Q6_V_vmux_QVV(q_underflow, zero_v, y);
}
#endif /* HVX_EXP_H */
+232
View File
@@ -0,0 +1,232 @@
#ifndef HVX_FA_KERNELS_H
#define HVX_FA_KERNELS_H
#include <assert.h>
#include <math.h>
#include "hvx-utils.h"
// Little inner kernels for HVX
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// This is a bit of a hack because the compiler is struggling to properly inline
// the default hvx_vec_f32_to_f16 with output into the local array.
static __attribute__((unused)) __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
{
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
}
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, rsum);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t nvec,
const size_t nloe) {
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
uint32_t i = 0;
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_Vector x2_hf = vx2[i];
HVX_Vector x3_hf = vx3[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
if (nloe) {
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
return hvx_vec_reduce_sum_f32x4(rsum0123);
}
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
const uint8_t * restrict x,
const size_t stride_x,
const size_t n,
float s) {
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
const size_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector sums = Q6_V_vzero();
const size_t stride_x_4 = stride_x * 4;
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
x += stride_x_4;
}
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
}
// MAD: y (F32) += x (F16) * s (F16)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
const __fp16 * restrict s0, const __fp16 * restrict s1, uint32_t n) {
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
HVX_Vector * restrict vy = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S0 = hvx_vec_splat_f16(*s0);
HVX_Vector S1 = hvx_vec_splat_f16(*s1);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; ++i) {
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
}
if (nloe) {
HVX_VectorPair xy_p = vy_p[i];
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
HVX_Vector xy = Q6_V_lo_W(xy_p);
i = 2 * i; // index for vy
if (nloe >= VLEN_FP32) {
vy[i] = xy;
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
}
if (nloe) {
hvx_vec_store_a(&vy[i], nloe * 4, xy);
}
}
}
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t n, HVX_Vector vs) {
assert((size_t) dst % 128 == 0);
assert((size_t) src % 128 == 0);
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
const uint32_t nvec = n / VLEN_FP32;
const uint32_t nloe = n % VLEN_FP32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
}
if (nloe) {
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
}
}
#endif /* HVX_FA_KERNELS_H */
+512 -25
View File
@@ -256,7 +256,7 @@ static inline void quantize_f16_f16_flat_kernel(
// Dot kernels that consume flat (non-tiled) activations
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -312,10 +312,14 @@ static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -397,11 +401,19 @@ static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -464,10 +476,14 @@ static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -561,11 +577,19 @@ static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -620,10 +644,14 @@ static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const v
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -704,11 +732,19 @@ static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -765,10 +801,14 @@ static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -851,11 +891,19 @@ static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -921,10 +969,14 @@ static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1019,6 +1071,441 @@ static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
rsum = HVX_OP_ADD_F32(rsum, prod);
}
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector rsum0 = Q6_V_vzero();
HVX_Vector rsum1 = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_sf = y[i];
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP32;
uint32_t nloe = n % VLEN_FP32;
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_sf = x0[i];
HVX_Vector r1_sf = x1[i];
HVX_Vector c0_sf = y0[i];
HVX_Vector c1_sf = y1[i];
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
}
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(s0, 8, r0_r1_c0_sum);
hvx_vec_store_u(s1, 8, r0_r1_c1_sum);
}
static inline void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
uint32_t nloe = n % VLEN_FP32; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
if (nloe) {
HVX_Vector x_sf = vx[i];
HVX_Vector y_sf = vy[i];
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
x_sf = Q6_V_vand_QV(bmask, x_sf);
y_sf = Q6_V_vand_QV(bmask, y_sf);
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
}
rsum = hvx_vec_reduce_sum_f32(rsum);
hvx_vec_store_u(&s[0], 4, rsum);
}
#undef HVX_OP_ADD_F32
#undef HVX_OP_MUL_F32
static inline void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_VectorPair rsum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
}
static inline void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
HVX_VectorPair rsum0_p = Q6_W_vzero();
HVX_VectorPair rsum1_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = y[i];
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
}
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
hvx_vec_store_u(s0, 8, rsum);
}
static inline void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1) {
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
uint32_t nvec = n / VLEN_FP16;
uint32_t nloe = n % VLEN_FP16;
// Row sums (sf) - 4 accumulators for 2x2 tile
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
HVX_Vector r0_hf = x0[i];
HVX_Vector r1_hf = x1[i];
HVX_Vector c0_hf = y0[i];
HVX_Vector c1_hf = y1[i];
// Compute 4 dot products: r0xc0, r0xc1, r1xc0, r1xc1
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
}
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
static inline void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector rsum = Q6_V_vzero();
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NANs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
// Convert into fp32 and reduce
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum);
}
static inline void hvx_tensor_add_f32_grid(
const struct htp_tensor * restrict dst,
const struct htp_tensor * restrict src2,
uint32_t start_row,
uint32_t end_row,
uint32_t start_col,
uint32_t end_col,
const struct fastdiv_values * div_ne11_12,
const struct fastdiv_values * div_ne11
) {
if (start_row >= end_row || start_col >= end_col) return;
const uint32_t nb1 = dst->nb[1]; // row stride in bytes
const uint32_t ne11 = dst->ne[1];
const uint32_t ne12 = dst->ne[2];
const uint32_t ne11_12 = ne11 * ne12;
const bool is_broadcast1 = (src2->ne[1] == 1);
const bool is_broadcast2 = (src2->ne[2] == 1);
const bool is_broadcast3 = (src2->ne[3] == 1);
for (uint32_t r = start_row; r < end_row; r++) {
float * dst_row = (float *) ((uint8_t *) dst->data + r * nb1);
uint32_t i13 = fastdiv(r, div_ne11_12);
uint32_t i12 = fastdiv(r - i13 * ne11_12, div_ne11);
uint32_t i11 = r - i13 * ne11_12 - i12 * ne11;
uint32_t i23 = is_broadcast3 ? 0 : i13;
uint32_t i22 = is_broadcast2 ? 0 : i12;
uint32_t i21 = is_broadcast1 ? 0 : i11;
const float * src2_row = (const float *) ((const uint8_t *) src2->data +
i21 * src2->nb[1] + i22 * src2->nb[2] + i23 * src2->nb[3]);
float * dst_ptr = &dst_row[start_col];
const float * src2_ptr = &src2_row[start_col];
int remaining = end_col - start_col;
while (remaining >= 32) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vmemu(dst_ptr) = hvx_vec_add_f32_f32(v_out, v_z);
dst_ptr += 32;
src2_ptr += 32;
remaining -= 32;
}
if (remaining > 0) {
HVX_Vector v_out = hvx_vmemu(dst_ptr);
HVX_Vector v_z = hvx_vmemu(src2_ptr);
hvx_vec_store_u(dst_ptr, remaining * sizeof(float), hvx_vec_add_f32_f32(v_out, v_z));
}
}
}
@@ -378,7 +378,7 @@ static inline HVX_VectorPair accum_q8_0_32x2(
return Q6_W_vcombine_VV(v_sum1, v_sum0);
}
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -401,10 +401,14 @@ static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -484,11 +488,19 @@ static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -519,10 +531,14 @@ static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -637,11 +653,19 @@ static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -663,10 +687,14 @@ static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -745,11 +773,19 @@ static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -773,10 +809,14 @@ static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
}
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -857,11 +897,19 @@ static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
}
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y_q = vy;
@@ -896,10 +944,14 @@ static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
if (sz) {
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
} else {
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
}
}
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
const uint8_t * restrict tile_ptr = vx;
const uint8_t * restrict y0_q = vy0;
const uint8_t * restrict y1_q = vy1;
@@ -1013,8 +1065,16 @@ static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, floa
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
if (sz0) {
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
} else {
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
}
if (sz1) {
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
} else {
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
}
}
static inline void quantize_f32_q8_0_tiled_kernel(
+39
View File
@@ -3,6 +3,7 @@
#include "hvx-base.h"
#include "hvx-inverse.h"
#include "hvx-exp.h"
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
@@ -139,4 +140,42 @@ static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restr
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline HVX_Vector hvx_vec_fast_sigmoid_f16(HVX_Vector x_v) {
const HVX_Vector v_one = hvx_vec_splat_f16(1.0f);
const HVX_Vector v_neg_log2e = hvx_vec_splat_f16(-EXP_LOG2E_F);
const HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
// Compute absolute value of x_v
HVX_Vector abs_x = Q6_V_vand_VV(x_v, em_mask);
// Compute u = -abs_x * log2(e) <= 0.
HVX_Vector u = hvx_vec_mul_f16_f16(abs_x, v_neg_log2e);
// Clamp input to prevent underflow in exp2
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
u = Q6_Vhf_vmax_VhfVhf(v_clamp_min, u);
HVX_Vector exp_val = hvx_vec_exp2_f16(u);
HVX_Vector denom = hvx_vec_add_f16_f16(v_one, exp_val);
HVX_Vector sig_abs = hvx_vec_inverse_f16(denom);
// check if x_v < 0 (using integer comparison on absolute value)
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(abs_x, x_v);
// If x_v < 0, return 1.0f - sig_abs
HVX_Vector sig_neg = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(v_one, sig_abs));
return Q6_V_vmux_QVV(is_neg, sig_neg, sig_abs);
}
static inline HVX_Vector hvx_vec_tanh_f16(HVX_Vector x) {
// tanh(x) = 2 * sigmoid(2x) - 1
const HVX_Vector v_two = hvx_vec_splat_f16(2.0f);
HVX_Vector x2 = hvx_vec_mul_f16_f16(x, v_two);
HVX_Vector sig2x = hvx_vec_fast_sigmoid_f16(x2);
const HVX_Vector v_neg_one = hvx_vec_splat_f16(-1.0f);
return hvx_vec_add_f16_f16(hvx_vec_mul_f16_f16(sig2x, v_two), v_neg_one);
}
#endif /* HVX_SIGMOID_H */
+1
View File
@@ -575,6 +575,7 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
static int execute_op(struct htp_ops_context * octx) {
switch (octx->op) {
case HTP_OP_MUL_MAT:
case HTP_OP_MUL_MAT_ADD:
return op_matmul(octx);
case HTP_OP_MUL_MAT_ID:
File diff suppressed because it is too large Load Diff
+19 -32
View File
@@ -392,56 +392,49 @@ static inline size_t htp_mm_hvx_get_vtcm_sizes(
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
if (dst_size_per_thread < quant_scratch_size_per_thread) {
dst_size_per_thread = quant_scratch_size_per_thread;
}
vtcm_dst_size = dst_size_per_thread * n_threads;
break;
}
default:
@@ -463,7 +456,8 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_row_size, // nb01
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out
size_t * vtcm_src1_size_out,
size_t * vtcm_dst_size_out
) {
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
@@ -476,29 +470,22 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (src0_sz_per_thread < src1_row_size_padded) {
src0_sz_per_thread = src1_row_size_padded;
}
if (is_repack) {
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
const uint32_t n_k_tiles = ne10 / 32;
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
}
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
const size_t vtcm_dst_size = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * n_threads;
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = src1_sz;
*vtcm_dst_size_out = vtcm_dst_size;
return vtcm_src0_size + src1_sz;
return vtcm_src0_size + src1_sz + vtcm_dst_size;
}
#ifdef __cplusplus
+5
View File
@@ -78,6 +78,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_f16_f32_l4
mul_mv_f16_f32
mul_mv_f32_f32
mul_mv_q1_0_f32
mul_mv_q1_0_f32_flat
mul_mv_q4_0_f32
mul_mv_q4_0_f32_v
mul_mv_q4_0_f32_8x_flat
@@ -128,6 +130,7 @@ set(GGML_OPENCL_KERNELS
moe_sort_by_expert
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q1_0_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q5_0_f32_l4_lm
@@ -137,6 +140,8 @@ set(GGML_OPENCL_KERNELS
mul_mm_q4_k_f32_l4_lm
mul_mm_q5_k_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
gemv_noshuffle_q1_0_f32
gemm_noshuffle_q1_0_f32
gemv_noshuffle_q4_0_f32
gemv_noshuffle_q4_0_f32_spec
gemm_noshuffle_q4_0_f32
+615
View File
@@ -631,6 +631,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mm_f16_f32_kqv;
cl_kernel kernel_mul_mm_f16_f32_kq;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q1_0, kernel_restore_block_q1_0;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns;
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
@@ -670,6 +671,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_convert_block_iq4_nl, kernel_restore_block_iq4_nl;
cl_kernel kernel_convert_block_iq4_nl_noshuffle;
cl_kernel kernel_restore_block_iq4_nl_noshuffle;
cl_kernel kernel_mul_mv_q1_0_f32, kernel_mul_mv_q1_0_f32_flat;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q4_1_f32;
cl_kernel kernel_mul_mv_q4_1_f32_flat;
@@ -733,6 +735,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
cl_kernel kernel_mul_mm_q1_0_f32_l4_lm;
cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
cl_kernel kernel_mul_mm_q5_0_f32_l4_lm;
@@ -890,6 +893,8 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gemm_noshuffle_q4_1_f32;
cl_kernel kernel_gemm_noshuffle_q8_0_f32;
cl_kernel kernel_gemv_noshuffle_q8_0_f32;
cl_kernel kernel_gemm_noshuffle_q1_0_f32;
cl_kernel kernel_gemv_noshuffle_q1_0_f32;
cl_kernel kernel_gemv_noshuffle_q4_k_f32;
cl_kernel kernel_gemm_noshuffle_q4_k_f32;
cl_kernel kernel_gemv_noshuffle_q6_K_f32;
@@ -1151,6 +1156,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
backend_ctx->program_cvt =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_convert_block_q1_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q1_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q1_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q1_0", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
@@ -1685,6 +1692,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// mul_mv_q1_0_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q1_0_f32.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q1_0_f32.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q1_0_f32 = clCreateKernel(prog, "kernel_mul_mv_q1_0_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q1_0_f32_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q1_0_f32_flat.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q1_0_f32_flat.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q1_0_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q1_0_f32_flat", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_iq4_nl_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1990,6 +2031,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// mul_mm_q1_0_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_q1_0_f32_l4_lm.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_q1_0_f32_l4_lm.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_q1_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q1_0_f32_l4_lm", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mm_iq4_nl_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2939,6 +2997,44 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
GGML_LOG_CONT(".");
}
// gemm_noshuffle_q1_0_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "gemm_noshuffle_q1_0_f32.cl.h"
};
#else
const std::string kernel_src = read_file("gemm_noshuffle_q1_0_f32.cl");
#endif
cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q1_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q1_0_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemv_noshuffle_q1_0_f32
{
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
" -cl-mad-enable "
" -DSIMDGROUP_WIDTH=" +
std::to_string(backend_ctx->adreno_wave_size);
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src_CL_gemv_general {
#include "gemv_noshuffle_q1_0_f32.cl.h"
};
#else
const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_q1_0_f32.cl");
#endif
cl_program prog = build_program_from_source(
backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);
CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q1_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q1_0_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// gemv_noshuffle_general
{
std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
@@ -4829,6 +4925,39 @@ struct ggml_tensor_extra_cl {
}
};
struct ggml_tensor_extra_cl_q1_0 {
cl_mem q = nullptr;
cl_mem q_img = nullptr;
cl_mem d = nullptr;
cl_mem d_img = nullptr;
size_t size_q = 0;
size_t size_d = 0;
~ggml_tensor_extra_cl_q1_0() {
reset();
}
void reset() {
// q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
// They must be properly released so that the original buffer can be
// properly released to avoid memory leak.
if (q != nullptr) {
CL_CHECK(clReleaseMemObject(q));
q = nullptr;
}
if (d != nullptr) {
CL_CHECK(clReleaseMemObject(d));
d = nullptr;
}
q_img = nullptr;
d_img = nullptr;
size_q = 0;
size_d = 0;
}
};
// Additional tensor extra structs for quantized tensors.
// These tensors are loaded from files and should not be allocated in scratch --
// they should always be allocated from the pool. Hence, they do not have an
@@ -5732,6 +5861,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return true;
} else if (op->src[0]->type == GGML_TYPE_F32) {
return op->src[1]->type == GGML_TYPE_F32;
} else if (op->src[0]->type == GGML_TYPE_Q1_0) {
return op->src[1]->type == GGML_TYPE_F32;
} else if (op->src[0]->type == GGML_TYPE_Q4_0) {
// Non-contig src0 routes through on-device dequant-to-f16.
return op->src[1]->type == GGML_TYPE_F32;
@@ -5988,6 +6119,12 @@ struct ggml_backend_opencl_buffer_context {
for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_q1_0 * e : temp_tensor_extras_q1_0) {
delete e;
}
for (ggml_tensor_extra_cl_q1_0 * e : temp_tensor_extras_q1_0_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl) {
delete e;
}
@@ -6029,6 +6166,21 @@ struct ggml_backend_opencl_buffer_context {
return extra;
}
ggml_tensor_extra_cl_q1_0 * ggml_opencl_alloc_temp_tensor_extra_q1_0() {
ggml_tensor_extra_cl_q1_0 * extra;
if (temp_tensor_extras_q1_0.empty()) {
extra = new ggml_tensor_extra_cl_q1_0();
} else {
extra = temp_tensor_extras_q1_0.back();
temp_tensor_extras_q1_0.pop_back();
}
temp_tensor_extras_q1_0_in_use.push_back(extra);
extra->reset();
return extra;
}
ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() {
ggml_tensor_extra_cl_q4_0 * extra;
if (temp_tensor_extras_q4_0.empty()) {
@@ -6185,6 +6337,11 @@ struct ggml_backend_opencl_buffer_context {
}
temp_tensor_extras_in_use.clear();
for (ggml_tensor_extra_cl_q1_0 * e : temp_tensor_extras_q1_0_in_use) {
temp_tensor_extras_q1_0.push_back(e);
}
temp_tensor_extras_q1_0_in_use.clear();
for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
temp_tensor_extras_q4_0.push_back(e);
}
@@ -6246,6 +6403,8 @@ struct ggml_backend_opencl_buffer_context {
// for reuse.
std::vector<ggml_tensor_extra_cl *> temp_tensor_extras;
std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
std::vector<ggml_tensor_extra_cl_q1_0 *> temp_tensor_extras_q1_0;
std::vector<ggml_tensor_extra_cl_q1_0 *> temp_tensor_extras_q1_0_in_use;
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1;
@@ -6353,6 +6512,82 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
cl_command_queue queue = backend_ctx->queue;
#ifdef GGML_OPENCL_SOA_Q
if (tensor->type == GGML_TYPE_Q1_0) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
// Allocate the new extra and create aliases from the original.
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
ggml_tensor_extra_cl_q1_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q1_0();
// q1_0 block = ggml_half d + (QK1_0/8) quant bytes = 2 + 16 = 18 bytes
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/8);
GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
CL_CHECK(clEnqueueWriteBuffer(
queue, data_device, CL_TRUE, 0,
ggml_nbytes(tensor), data, 0, NULL, NULL));
// The original tensor memory is divided into scales and quants, i.e.,
// we first store scales, then quants.
cl_buffer_region region;
// Create subbuffer for scales.
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Create subbuffer for quants.
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
region.size = size_q;
extra->q = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_convert_block_q1_0;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
// q is uint32 (32 sign bits each); d is one half per 128-block.
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (enable_adreno_trans_weight(backend_ctx, tensor)) {
int M = tensor->ne[1]; // ne01
int K = tensor->ne[0]; // ne00
GGML_ASSERT(K % 128 == 0);
GGML_ASSERT(M % 4 == 0);
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
transpose_2d_as_32b(backend_ctx, extra->q, extra->q, size_q, K/32, M);
transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/128, M);
} // end transpose
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
return;
}
// We separate the quantized bits and scale from block_q4_0 by using an
// additional kernel, where each thread handles a block. We first read the
// original weights into a temporary buffer, then create two separate
@@ -7743,6 +7978,63 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
sync_with_other_backends(backend_ctx);
#ifdef GGML_OPENCL_SOA_Q
if (tensor->type == GGML_TYPE_Q1_0) {
ggml_tensor_extra_cl_q1_0 * extra = (ggml_tensor_extra_cl_q1_0 *)tensor->extra;
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
if (enable_adreno_trans_weight(backend_ctx, tensor)) {
ggml_cl_buffer buf_trans_q;
ggml_cl_buffer buf_trans_d;
ggml_cl_buffer buf_unpacked;
int M = tensor->ne[1];
int K = tensor->ne[0];
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/8);
buf_trans_q.allocate(backend_ctx->context, size_q);
buf_trans_d.allocate(backend_ctx->context, size_d);
buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor));
transpose_2d_as_32b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/32);
transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/128);
cl_kernel kernel = backend_ctx->kernel_restore_block_q1_0;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL));
return;
}
#endif
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_restore_block_q1_0;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
// In end-to-end runs, get_tensor is usually used to get back the logits,
// where we can simply do clEnqueueReadBuffer since they are f32.
// However, in test-backend-ops, the GPU graph is copied to the CPU backend,
@@ -13437,6 +13729,203 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten
CL_CHECK(clReleaseMemObject(D_sub_buffer));
}
static void ggml_cl_mul_mat_q1_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
GGML_ASSERT(src0->type == GGML_TYPE_Q1_0);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl_q1_0 * extra0_q1_0 = (ggml_tensor_extra_cl_q1_0 *)src0->extra;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
GGML_ASSERT(src1->view_offs == 0);
GGML_ASSERT(dst->view_offs == 0);
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne10 = src1->ne[0];
const int ne12 = src1->ne[2];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
GGML_ASSERT(ne00 == ne10);
GGML_ASSERT((ne00 % 128) == 0);
GGML_ASSERT(ne0 == ne01);
cl_context context = backend_ctx->context;
cl_kernel kernel;
cl_int err;
cl_image_format img_fmt;
cl_image_desc img_desc;
cl_buffer_region region;
int M = ne01;
int N = ne1;
int K = ne00;
if (ne1 == 1) {
cl_mem q_img = nullptr;
cl_mem b_sub_buf = nullptr;
cl_mem b_img = nullptr;
// image for q (uint32: each texel packs 32 sign bits)
img_fmt = { CL_R, CL_UNSIGNED_INT32};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = M * K / 32;
img_desc.buffer = extra0_q1_0->q;
CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// create a sub_buffer for B
region.origin = offset1;
region.size = K * N * sizeof(float);
CL_CHECK((b_sub_buf = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for activations
img_fmt = {CL_RGBA, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K * N / 4;
img_desc.buffer = b_sub_buf;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
kernel = backend_ctx->kernel_gemv_noshuffle_q1_0_f32;
int r2 = 1;
int r3 = 1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q1_0->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &extra1->offset));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &extrad->offset));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
size_t wavesize = backend_ctx->adreno_wave_size;
size_t local_work_size[] = { wavesize, 4, 1 };
size_t global_work_size[] = { CEIL_DIV(M, wavesize)*wavesize, 4, 1 };
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(q_img));
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(b_sub_buf));
} else {
cl_mem b_sub_buf = nullptr;
cl_mem b_sub_buf_trans = nullptr;
cl_mem b_img = nullptr;
cl_mem b_img_trans = nullptr;
// subbuffer for activations
region.origin = offset1;
region.size = K * N * sizeof(float);
CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for activations
img_fmt = {CL_RGBA, CL_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K * N / 4;
img_desc.buffer = b_sub_buf;
CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
// pad N to multiple of 8
int extra_elements = N % 8;
int padding = 0;
if (extra_elements > 0){
padding = 8 - extra_elements;
}
// subbuffer for transposed activations
region.origin = 0;
region.size = K * (N + padding) * sizeof(float)/2;
backend_ctx->prealloc_act_trans.allocate(context, region.size);
CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
// image for transposed activations
img_fmt = {CL_RGBA, CL_HALF_FLOAT};
memset(&img_desc, 0, sizeof(img_desc));
img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
img_desc.image_width = K * (N + padding) / 4;
img_desc.buffer = b_sub_buf_trans;
CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));
// transpose activations
int height_B = N/4;
if (height_B == 0) {
height_B = 1;
}
int width_B = K/4;
int padded_height_B = (N + padding)/4;
kernel = backend_ctx->kernel_transpose_32_16;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B));
size_t local_work_size_t[2] = { 1, 16 };
size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst);
// gemm
kernel = backend_ctx->kernel_gemm_noshuffle_q1_0_f32;
int padded_N = N + padding;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q1_0->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q1_0->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &padded_N));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = { (size_t)CEIL_DIV(N, 8), (size_t)CEIL_DIV(M, 4), 1 };
size_t local_work_size[] = { 2, 128, 1 };
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
CL_CHECK(clReleaseMemObject(b_img_trans));
CL_CHECK(clReleaseMemObject(b_sub_buf_trans));
CL_CHECK(clReleaseMemObject(b_img));
CL_CHECK(clReleaseMemObject(b_sub_buf));
}
#else
GGML_UNUSED(backend);
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
#endif
}
static void ggml_cl_mul_mat_q4_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
GGML_ASSERT(src0);
@@ -15311,6 +15800,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
// view->extra stays pre-SoA; cast to the SoA struct would SIGSEGV.
// Follow view_src to reach the real SoA extra.
const ggml_tensor * soa0_src = src0->view_src != nullptr ? src0->view_src : src0;
ggml_tensor_extra_cl_q1_0 * extra0_q1_0 = (ggml_tensor_extra_cl_q1_0 *)src0->extra;
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)soa0_src->extra;
ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)soa0_src->extra;
ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)soa0_src->extra;
@@ -15374,6 +15864,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
// a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that
// limit, so the check is omitted.
// q1_0 x fp32
if (src0t == GGML_TYPE_Q1_0 && src1t == GGML_TYPE_F32 &&
enable_adreno_trans_weight(backend_ctx, src0)) {
ggml_cl_mul_mat_q1_0_f32_adreno(backend, src0, src1, dst);
return;
}
// q4_0 x fp32
if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) {
ggml_cl_mul_mat_q4_0_f32_adreno(backend, src0, src1, dst);
@@ -15577,6 +16074,48 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q1_0: {
if (ne11 < 32) {
break;
}
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q1_0_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
int batch_stride_a = ne00*ne01;
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q1_0->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q1_0->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q4_0: {
if (ne11 < 32) {
break;
@@ -16165,6 +16704,81 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3));
break;
case GGML_TYPE_Q1_0: {
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_q1_0_f32_flat;
// nth0 - subgroup size
// nth1 - number of subgroups per workgroup
// ndst - number of output values per workgroup = output per subgroup * number of subgroups
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 2;
ndst = nth1*4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 2;
ndst = nth1*4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q1_0->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q1_0->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
#else
kernel = backend_ctx->kernel_mul_mv_q1_0_f32;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 2;
ndst = nth1*4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 2;
ndst = nth1*4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
break;
}
case GGML_TYPE_Q4_0:
// This should have been satisfied.
GGML_ASSERT(ne11 == ne1);
@@ -16879,6 +17493,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 ||
src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_Q1_0 ||
src0t == GGML_TYPE_IQ4_NL ||
src0t == GGML_TYPE_Q2_K) {
// Each SIMD group produces N_DST values in the result. Assuming each
+46
View File
@@ -27,6 +27,8 @@
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK1_0 128
#define QR1_0 1
#define QK_K 256
#define K_SCALE_SIZE (3 * QK_K / 64)
#define K_QUANTS_PER_ITERATION 2
@@ -38,6 +40,14 @@ typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q1_0
//------------------------------------------------------------------------------
typedef struct {
half d; // delta
uchar qs[QK1_0/8]; // 1-bit signs (16 bytes)
} block_q1_0;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
@@ -159,6 +169,42 @@ kernel void kernel_convert_f16_to_bf16(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q1_0
// Convert block_q1_0 (AOS) to 2 separate arrays (SOA): quant bytes + scales.
// q1_0 bits are stored in natural order (bit j of byte i -> weight 8*i + j)
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q1_0(
global block_q1_0 * src0,
global uchar * dst_q,
global half * dst_d
) {
global block_q1_0 * b = (global block_q1_0 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + (QK1_0/8)*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK1_0/8; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q1_0(
global uchar * src_q,
global half * src_d,
global block_q1_0 * dst
) {
global block_q1_0 * b = (global block_q1_0 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + (QK1_0/8)*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
b->d = *d;
for (int i = 0; i < QK1_0/8; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@@ -0,0 +1,94 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
// each work-item computes a 4 (rows of A / m) x 8 (cols of B / n) output tile.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_gemm_noshuffle_q1_0_f32(
global const uint * src0_q,
global const half * src0_d,
read_only image1d_buffer_t src1,
global float * dst,
int k,
int m,
int n,
int n_no_padding,
ulong offsetd
) {
int n_4 = n >> 2;
int gy = get_global_id(0);
int gx = get_global_id(1);
int gx_2 = gx << 2;
dst = (global float *)((global char*)dst + offsetd);
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
half8 B;
global const uint* wptr = src0_q + gx_2;
global const half* sptr = src0_d + gx_2;
// 32 weights per uint32, 128 weights (one block / one scale) per 4 uint32.
for (int i = 0; i < k; i += 32) {
uint4 pack4 = vload4(0, wptr + (i / 32) * m); // 4 rows, 32 K-values each
half4 scale = vload4(0, sptr + (i / 128) * m); // 4 rows, one scale per 128
for (int j = 0; j < 32; ++j) {
B.s0123 = read_imageh(src1, gy * 2 + (i + j) * n_4);
B.s4567 = read_imageh(src1, gy * 2 + (i + j) * n_4 + 1);
// sign bit -> +-1 (half arithmetic avoids unsigned underflow)
half4 wj = (half4)(
2.0h * (half)((pack4.s0 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s1 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s2 >> j) & 1u) - 1.0h,
2.0h * (half)((pack4.s3 >> j) & 1u) - 1.0h) * scale;
c0 += B * wj.s0;
c1 += B * wj.s1;
c2 += B * wj.s2;
c3 += B * wj.s3;
}
}
int idx = (gy << 3) * m + (gx << 2);
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
}
}
@@ -0,0 +1,121 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif
#define QK1_0 128
#define N_SIMDGROUP 4
#define dequantizeBlockAccum_q1(total, bits, scale, regB, lb) \
total += (2.0f*(float)((bits >> 0) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+0); \
total += (2.0f*(float)((bits >> 1) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+0); \
total += (2.0f*(float)((bits >> 2) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+0); \
total += (2.0f*(float)((bits >> 3) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+0); \
total += (2.0f*(float)((bits >> 4) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+0); \
total += (2.0f*(float)((bits >> 5) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+0); \
total += (2.0f*(float)((bits >> 6) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+0); \
total += (2.0f*(float)((bits >> 7) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+0); \
total += (2.0f*(float)((bits >> 8) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+1); \
total += (2.0f*(float)((bits >> 9) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+1); \
total += (2.0f*(float)((bits >> 10) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+1); \
total += (2.0f*(float)((bits >> 11) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+1); \
total += (2.0f*(float)((bits >> 12) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+1); \
total += (2.0f*(float)((bits >> 13) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+1); \
total += (2.0f*(float)((bits >> 14) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+1); \
total += (2.0f*(float)((bits >> 15) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+1); \
total += (2.0f*(float)((bits >> 16) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+2); \
total += (2.0f*(float)((bits >> 17) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+2); \
total += (2.0f*(float)((bits >> 18) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+2); \
total += (2.0f*(float)((bits >> 19) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+2); \
total += (2.0f*(float)((bits >> 20) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+2); \
total += (2.0f*(float)((bits >> 21) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+2); \
total += (2.0f*(float)((bits >> 22) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+2); \
total += (2.0f*(float)((bits >> 23) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+2); \
total += (2.0f*(float)((bits >> 24) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+3); \
total += (2.0f*(float)((bits >> 25) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+3); \
total += (2.0f*(float)((bits >> 26) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+3); \
total += (2.0f*(float)((bits >> 27) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+3); \
total += (2.0f*(float)((bits >> 28) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+3); \
total += (2.0f*(float)((bits >> 29) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+3); \
total += (2.0f*(float)((bits >> 30) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+3); \
total += (2.0f*(float)((bits >> 31) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+3);
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
__kernel void kernel_gemv_noshuffle_q1_0_f32(
read_only image1d_buffer_t src0_q,
global half * src0_d,
read_only image1d_buffer_t src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3)
{
uint groupId = get_local_id(1);
uint gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
uint K = ne00;
uint M = ne01;
uint LINE_STRIDE_A = M;
uint BLOCK_STRIDE_A = 4 * M;
uint4 regA;
half regS;
float8 regB;
float totalSum = 0.0f;
#pragma unroll 1
for (uint kb = groupId; kb < (K / QK1_0); kb += N_SIMDGROUP) {
regS = src0_d[gid + kb * LINE_STRIDE_A]; // each fiber loads its row's scale
// first 16 fibers load 8 B values each -> 128 activations for this block
if (slid < 16) {
regB.s0123 = read_imagef(src1, (slid * 2 + kb * 32));
regB.s4567 = read_imagef(src1, (1 + slid * 2 + kb * 32));
}
// load this row's 4 uint32 (128 sign bits)
regA.s0 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
regA.s1 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
regA.s2 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
regA.s3 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
float scale = (float)regS;
dequantizeBlockAccum_q1(totalSum, regA.s0, scale, regB, 0);
dequantizeBlockAccum_q1(totalSum, regA.s1, scale, regB, 4);
dequantizeBlockAccum_q1(totalSum, regA.s2, scale, regB, 8);
dequantizeBlockAccum_q1(totalSum, regA.s3, scale, regB, 12);
}
// reduction in local memory, assumes #wave = N_SIMDGROUP = 4
local float reduceLM[SIMDGROUP_WIDTH * 3];
if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
barrier(CLK_LOCAL_MEM_FENCE);
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
if (groupId == 0) {
dst = (global float*)((global char*)dst + offsetd);
dst[gid] = totalSum;
}
}
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// LOAD_VEC_A is 8 because one q1_0 quant byte expands to 8 weights along K.
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q1_0_f32_l4_lm(
global uchar * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 16; // 16 quant bytes per q1_0 block
float d = (float)src0_d[ib];
uint bits = src0_q[idx];
// use float to avoid unsigned underflow of (2*0 - 1).
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 0) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 1) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 2) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 3) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 4) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 4) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 5) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 5) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 6) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 6) & 1) - 1.0f);
buf_a[(loadr_a * LOAD_VEC_A + 7) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 7) & 1) - 1.0f);
} else {
for (int b = 0; b < LOAD_VEC_A; ++b) {
buf_a[(loadr_a * LOAD_VEC_A + b) * BM + loadc_a + l] = 0.0f;
}
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}
@@ -0,0 +1,141 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK1_0 128
typedef struct {
half d;
uchar qs[QK1_0/8];
} block_q1_0;
#define NB_Q1_0 16
#ifdef INTEL_GPU
#define N_R0_Q1_0 4 // number of rows each subgroup works on
#define N_SG_Q1_0 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_Q1_0 4
#define N_SG_Q1_0 2
#define N_SIMDWIDTH 64
#endif
inline float block_q_1_0_dot_y(global block_q1_0 * qb, float sumy, float yl[NB_Q1_0], short il) {
global uchar * qs = qb->qs + il*2;
uint b0 = qs[0];
uint b1 = qs[1];
float acc = 0.f;
acc += yl[ 0]*(float)((b0 >> 0) & 1) + yl[ 1]*(float)((b0 >> 1) & 1);
acc += yl[ 2]*(float)((b0 >> 2) & 1) + yl[ 3]*(float)((b0 >> 3) & 1);
acc += yl[ 4]*(float)((b0 >> 4) & 1) + yl[ 5]*(float)((b0 >> 5) & 1);
acc += yl[ 6]*(float)((b0 >> 6) & 1) + yl[ 7]*(float)((b0 >> 7) & 1);
acc += yl[ 8]*(float)((b1 >> 0) & 1) + yl[ 9]*(float)((b1 >> 1) & 1);
acc += yl[10]*(float)((b1 >> 2) & 1) + yl[11]*(float)((b1 >> 3) & 1);
acc += yl[12]*(float)((b1 >> 4) & 1) + yl[13]*(float)((b1 >> 5) & 1);
acc += yl[14]*(float)((b1 >> 6) & 1) + yl[15]*(float)((b1 >> 7) & 1);
return qb->d * (2.0f*acc - sumy);
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q1_0_f32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
int nb = ne00/QK1_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
uint i12 = im%ne12;
uint i13 = im/ne12;
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
global float * y = (global float *) (src1 + offset_src1);
// pointers to src0 rows
global block_q1_0 * ax[N_R0_Q1_0];
for (int row = 0; row < N_R0_Q1_0; ++row) {
ulong offset_src0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
ax[row] = (global block_q1_0 *) ((global char *) src0 + offset_src0);
}
float yl[NB_Q1_0];
float sumf[N_R0_Q1_0] = { 0.f };
const short ix = get_sub_group_local_id()/8;
const short il = get_sub_group_local_id()%8;
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
// each thread handles NB_Q1_0 quants at a time
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
float sumy = 0.f;
for (short i = 0; i < NB_Q1_0; ++i) {
yl[i] = yb[i];
sumy += yb[i];
}
for (short row = 0; row < N_R0_Q1_0; row++) {
sumf[row] += block_q_1_0_dot_y(ax[row] + ib, sumy, yl, il);
}
yb += N_SIMDWIDTH*NB_Q1_0;
}
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
for (int row = 0; row < N_R0_Q1_0; ++row) {
float tot = sub_group_reduce_add(sumf[row]);
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
dst_f32[first_row + row] = tot;
}
}
}
@@ -0,0 +1,190 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK1_0 128
#define QK1_0_BYTES (QK1_0/8) // 16 quant bytes per block
#define QK1_0_BLK_BYTES (QK1_0_BYTES + 2) // d + qs in original tensor = 18
#define NB_Q1_0 16 // quants handled per thread (two qs bytes)
#ifdef INTEL_GPU
#define N_R0_Q1_0 4 // number of rows each subgroup works on
#define N_SG_Q1_0 2 // number of subgroups in a work group
#define N_SIMDWIDTH 16 // subgroup size
#elif defined (ADRENO_GPU)
#define N_R0_Q1_0 4
#define N_SG_Q1_0 2
#define N_SIMDWIDTH 64
#endif
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q1_0_f32_flat(
global char * src0_q,
global half * src0_d,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
int nb = ne00/QK1_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
uint i12 = im%ne12;
uint i13 = im/ne12;
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
global float * y = (global float *) (src1 + offset_src1);
// pointers to src0 rows (flat: q bytes + scales)
uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
global uchar * ax0, * ax1, * ax2, * ax3;
global half * ad0, * ad1, * ad2, * ad3;
uint offset_src0;
offset_src0 = (offset_src0_base + 0*nb01) / QK1_0_BLK_BYTES;
ax0 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 1*nb01) / QK1_0_BLK_BYTES;
ax1 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 2*nb01) / QK1_0_BLK_BYTES;
ax2 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
offset_src0 = (offset_src0_base + 3*nb01) / QK1_0_BLK_BYTES;
ax3 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
const short ix = get_sub_group_local_id()/8;
const short il = get_sub_group_local_id()%8;
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
float8 yl_lo;
float8 yl_hi;
float4 sumf = 0.f;
// each thread handles NB_Q1_0 = 16 quants (two qs bytes) at a time
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
yl_lo = vload8(0, yb);
yl_hi = vload8(0, yb + 8);
float sumy = yl_lo.s0 + yl_lo.s1 + yl_lo.s2 + yl_lo.s3
+ yl_lo.s4 + yl_lo.s5 + yl_lo.s6 + yl_lo.s7
+ yl_hi.s0 + yl_hi.s1 + yl_hi.s2 + yl_hi.s3
+ yl_hi.s4 + yl_hi.s5 + yl_hi.s6 + yl_hi.s7;
uint b0, b1;
float acc;
b0 = ax0[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax0[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s0 += (float)ad0[ib] * (2.0f*acc - sumy);
b0 = ax1[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax1[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s1 += (float)ad1[ib] * (2.0f*acc - sumy);
b0 = ax2[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax2[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s2 += (float)ad2[ib] * (2.0f*acc - sumy);
b0 = ax3[ib*QK1_0_BYTES + il*2 + 0];
b1 = ax3[ib*QK1_0_BYTES + il*2 + 1];
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
sumf.s3 += (float)ad3[ib] * (2.0f*acc - sumy);
yb += N_SIMDWIDTH*NB_Q1_0;
}
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0),
sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2),
sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) dst_f32[first_row + 0] = tot.s0;
if (first_row + 1 < ne01) dst_f32[first_row + 1] = tot.s1;
if (first_row + 2 < ne01) dst_f32[first_row + 2] = tot.s2;
if (first_row + 3 < ne01) dst_f32[first_row + 3] = tot.s3;
}
}
+153 -39
View File
@@ -1907,6 +1907,38 @@ static bool vk_enable_sync_logger = false;
static uint32_t vk_perf_logger_frequency = 1;
static std::string vk_pipeline_stats_filter;
static uint64_t ggml_vk_get_node_flops(const ggml_tensor * node) {
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->ne[0];
const uint64_t n = node->ne[1];
const uint64_t k = node->src[1]->ne[0];
const uint64_t batch = node->ne[2] * node->ne[3];
return m * n * (k + (k - 1)) * batch;
}
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
const ggml_tensor * knl = node->src[0];
const uint64_t Cout = node->ne[2];
const uint64_t size_K = node->src[1]->ne[2] * knl->ne[0] * knl->ne[1];
const uint64_t size_N = node->ne[3] * node->ne[0] * node->ne[1];
return Cout * size_N * (size_K + (size_K - 1));
}
if (node->op == GGML_OP_CONV_3D) {
const ggml_tensor * knl = node->src[0];
const uint64_t OC = ggml_get_op_params_i32(node, 11);
const uint64_t IC = ggml_get_op_params_i32(node, 9);
const uint64_t size_K = IC * knl->ne[0] * knl->ne[1] * knl->ne[2];
const uint64_t size_N = node->ne[3] / OC * node->ne[0] * node->ne[1] * node->ne[2];
return OC * size_N * (size_K + (size_K - 1));
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * q = node->src[0];
const ggml_tensor * k = node->src[1];
const ggml_tensor * v = node->src[2];
return 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
}
return 0;
}
class vk_perf_logger {
public:
void print_timings(bool force = false) {
@@ -1955,7 +1987,7 @@ class vk_perf_logger {
}
std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
*n_flops = 0;
*n_flops = ggml_vk_get_node_flops(node);
std::string fusion_str;
if (fusion_name) {
fusion_str = fusion_name + std::string(" ");
@@ -1982,35 +2014,22 @@ class vk_perf_logger {
if (batch > 1) {
name += " batch=" + std::to_string(batch);
}
name = fusion_str + name;
*n_flops = m * n * (k + (k - 1)) * batch;
return name;
return fusion_str + name;
}
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
std::string name = ggml_op_name(node->op);
ggml_tensor * knl = node->src[0];
uint64_t OW = node->ne[0];
uint64_t OH = node->ne[1];
uint64_t N = node->ne[3];
const ggml_tensor * knl = node->src[0];
uint64_t Cout = node->ne[2];
uint64_t KW = knl->ne[0];
uint64_t KH = knl->ne[1];
uint64_t Cin = node->src[1]->ne[2];
// KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
uint64_t size_M = Cout;
uint64_t size_K = Cin * KW * KH;
uint64_t size_N = N * OW * OH;
*n_flops = size_M * size_N * (size_K + (size_K - 1));
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
uint64_t size_K = node->src[1]->ne[2] * knl->ne[0] * knl->ne[1];
uint64_t size_N = node->ne[3] * node->ne[0] * node->ne[1];
name += " M=Cout=" + std::to_string(Cout) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
", N=N*OW*OH=" + std::to_string(size_N);
name = fusion_str + name;
return name;
return fusion_str + name;
}
if (node->op == GGML_OP_RMS_NORM) {
std::string name = ggml_op_name(node->op);
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
name = fusion_str + name;
return name;
return fusion_str + name;
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * dst = node;
@@ -2026,7 +2045,6 @@ class vk_perf_logger {
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
*n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
return name.str();
}
if (node->op == GGML_OP_TOP_K) {
@@ -2090,7 +2108,7 @@ struct ggml_backend_vk_context {
bool do_add_rms_partials_offset_calculation;
bool do_add_rms_partials;
uint64_t last_total_mul_mat_bytes {};
uint64_t last_total_flops {UINT64_MAX};
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
@@ -2457,6 +2475,85 @@ static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count
return true;
}
// Remove the loop unrolling hint of the matmul shader's BK loop
// and replace it with the dont_unroll hint for better performance on
// hardware like Apple M1/M2.
// Assumes 1. code comes from mul_mm.comp 2. the K-tile loop has no loop
// control hint and 3. the BK loop is the last loop nested directly inside
// the K-tile loop.
// Returns true when the input was modified; returns false otherwise
// without touching `out`.
static bool ggml_vk_roll_bk_loop(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
if (word_count < 5) {
return false;
}
struct vk_spv_loop {
size_t header;
size_t end;
uint32_t control;
};
std::vector<vk_spv_loop> loops;
// Collect a list of all loops in the module.
for (size_t pos = 5; pos < word_count; ) {
const uint32_t wc = code[pos] >> spv::WordCountShift;
const uint32_t op = code[pos] & spv::OpCodeMask;
if (wc == 0 || pos + wc > word_count) {
return false;
}
if (op == spv::OpLoopMerge && wc >= 4) { loops.push_back({ pos, 0, code[pos + 3] }); }
if (op == spv::OpLabel && wc >= 2) {
for (auto & l : loops) {
if (l.end == 0 && code[l.header + 1] == code[pos + 1]) { l.end = pos; }
}
}
pos += wc;
}
auto encloses = [](const vk_spv_loop & a, const vk_spv_loop & b) {
return a.header < b.header && b.header < a.end;
};
// Find the BK loop.
const vk_spv_loop * bk = nullptr;
for (const auto & h : loops) {
if (h.control != spv::LoopControlUnrollMask) {
continue;
}
const vk_spv_loop * parent = nullptr;
bool has_child = false;
for (const auto & g : loops) {
if (encloses(g, h) && (!parent || g.header > parent->header)) {
parent = &g;
}
if (encloses(h, g)) {
has_child = true;
}
}
// BK loop should be the last loop nested inside the loop with no hint
// and have at least one child loop.
if (parent &&
parent->control == spv::LoopControlMaskNone &&
has_child &&
(!bk || h.header > bk->header)) {
bk = &h;
}
}
if (!bk) {
return false;
}
// set DontUnroll instead of Unroll
out.assign(code, code + word_count);
out[bk->header + 3] = spv::LoopControlDontUnrollMask;
return true;
}
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
@@ -2540,6 +2637,22 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
}
#endif
#if VK_HEADER_VERSION >= 287
// Roll the mul_mm BK loop on Asahi Linux. Skip bf16 and the mul_mmq pipelines.
if (device->driver_id == vk::DriverId::eMesaHoneykrisp &&
pipeline->name.rfind("matmul", 0) == 0 &&
pipeline->name.find("bf16") == std::string::npos &&
pipeline->name.find("q8_1") == std::string::npos) {
const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
std::vector<uint32_t> rolled;
if (ggml_vk_roll_bk_loop(src, src_n, rolled)) {
spirv = std::move(rolled);
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
}
}
#endif
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
vk::PushConstantRange pcr(
@@ -16188,22 +16301,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
}
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).
int submitted_nodes = 0;
int submit_count = 0;
uint64_t mul_mat_bytes = 0;
uint64_t total_mul_mat_bytes = 0;
uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
// Estimate the amount of compute work using flops, and submit every 200 GFLOP
// (and scaled down based on total graph flops, so smaller models submit earlier).
// Also submit at least every 100 nodes, in case there are workloads without heavy compute.
uint32_t submitted_nodes = 0;
uint32_t submit_count = 0;
uint64_t batch_flops = 0;
uint64_t total_flops = 0;
uint64_t flops_per_submit = std::min(uint64_t(200'000'000'000), ctx->last_total_flops / 40u);
for (int i = 0; i < cgraph->n_nodes; i++) {
if (first_node_in_batch) {
submit_node_idx = i;
}
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
mul_mat_bytes += bytes;
total_mul_mat_bytes += bytes;
{
auto node_flops = ggml_vk_get_node_flops(cgraph->nodes[i]);
batch_flops += node_flops;
total_flops += node_flops;
}
// op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
@@ -16415,8 +16529,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
(mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
bool submit = (submitted_nodes >= ctx->device->max_nodes_per_submit) ||
(flops_per_submit != 0 && batch_flops >= flops_per_submit) ||
(i + ctx->num_additional_fused_ops >= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
@@ -16450,9 +16564,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
if (submit && enqueued) {
first_node_in_batch = true;
submitted_nodes = 0;
mul_mat_bytes = 0;
batch_flops = 0;
if (submit_count < 3) {
mul_mat_bytes_per_submit *= 2;
flops_per_submit *= 2;
}
submit_count++;
}
@@ -16461,7 +16575,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->fused_ops_write_mask = 0;
}
ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
ctx->last_total_flops = total_flops;
if (vk_perf_logger_enabled) {
// End the command buffer and submit/wait
@@ -1563,6 +1563,7 @@ class ggml_webgpu_shader_lib {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
{
// Quantized types using u32 buffers for portability.
defines.push_back("SRC_TYPE=u32");
@@ -1593,6 +1594,8 @@ class ggml_webgpu_shader_lib {
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
defines.push_back("BLOCK_SIZE=32u");
} else if (key.src_type == GGML_TYPE_NVFP4) {
defines.push_back("BLOCK_SIZE=64u");
} else if (key.src_type >= GGML_TYPE_Q2_K) {
defines.push_back("BLOCK_SIZE=256u");
} else {
@@ -1960,6 +1963,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
@@ -2103,6 +2107,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
@@ -2274,6 +2279,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
@@ -2394,6 +2400,7 @@ class ggml_webgpu_shader_lib {
defines.push_back(type_upper + "_TABLES");
break;
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
defines.push_back(type_upper + "_LUT");
break;
default:
+3
View File
@@ -4056,6 +4056,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
return true;
default:
return false;
@@ -4156,6 +4157,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
supports_op = true;
break;
default:
@@ -4196,6 +4198,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
supports_op = true;
break;
default:
@@ -896,9 +896,23 @@ const kvalues_iq4nl = array<i32, 16>(
#endif
#ifdef MXFP4_LUT
#if defined(MXFP4_LUT) || defined(NVFP4_LUT)
const kvalues_mxfp4 = array<i32, 16>(
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
);
#endif
#endif // MXFP4_LUT || NVFP4_LUT
#ifdef NVFP4_LUT
fn ue4m3_to_fp32(u: u32) -> f32 {
if (u == 0u || u == 127u) {
return 0.0;
}
let exp = (u >> 3u) & 15u;
let man = u & 7u;
if (exp == 0u) {
return f32(man) * (1.0 / 512.0);
}
let bits = ((exp + 120u) << 23u) | (man << 20u);
return bitcast<f32>(bits);
}
#endif // NVFP4_LUT
@@ -672,6 +672,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
#endif
#ifdef NVFP4
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 36;
let d_word = load_u32_at_src(block_byte_base);
for (var sub: u32 = 0u; sub < 4; sub++) {
let d = ue4m3_to_fp32(get_byte(d_word, sub)) * 0.5;
for (var j: u32 = 0u; j < 2; j++) {
let q_packed = load_u32_at_src(block_byte_base + 4 + sub * 8 + j * 4);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d;
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d;
let dst_offset = dst_base + offset * 64 + sub * 16 + j * 4 + k;
dst[dst_offset] = q_lo;
dst[dst_offset + 8u] = q_hi;
}
}
}
}
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<SRC_TYPE>;
@@ -241,7 +241,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#endif // INIT_SRC0_SHMEM_Q8_1
#if defined(INIT_SRC0_SHMEM_MXFP4)
let block_byte_base = src0_idx * 17u;
let block_byte_base = src0_idx * 17u; // BLOCK_SIZE_BYTES = 17u;
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
let e = ldexp(1.0, i32(eu8) - 128);
@@ -263,6 +263,47 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
#endif // legacy-quants
#if defined(INIT_SRC0_SHMEM_NVFP4)
const BLOCK_SIZE = 64u;
const BLOCK_SIZE_BYTES = 36u;
const SUB_BLOCK_SIZE = 16u; // elements sharing one UE4M3 scale
const NQ = 16u;
const BYTES_PER_THREAD = 8u;
const BYTES_PER_INNER_LOOP = 4u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let tile_m = i / TILE_K;
let tile_k_start = i % TILE_K;
let global_m = offset_m + tile_m;
let global_k_start = k_outer + tile_k_start;
if (global_m >= params.m) {
break;
}
let block_k = global_k_start / BLOCK_SIZE;
let sub_block = (global_k_start % BLOCK_SIZE) / SUB_BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d_byte_base = block_byte_base;
let qs_byte_base = block_byte_base + 4u;
let d = ue4m3_to_fp32(get_byte(load_u32_at_src0_aligned(d_byte_base), sub_block)) * 0.5;
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j++) {
let q_packed = load_u32_at_src0_aligned(qs_byte_base + sub_block * 8u + j * 4u);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
shmem[i + j * BYTES_PER_INNER_LOOP + k] = f16(f32(kvalues_mxfp4[q_byte & 0xF]) * d);
shmem[i + j * BYTES_PER_INNER_LOOP + k + 8u] = f16(f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d);
}
}
}
}
#endif // INIT_SRC0_SHMEM_NVFP4
// k-quants
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
const BLOCK_SIZE = 256u;
@@ -1505,3 +1505,49 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src
return acc;
}
#endif
#ifdef MUL_ACC_NVFP4
#define BLOCK_SIZE 64
#define BLOCK_SIZE_BYTES 36
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
let num_blocks = params.k / BLOCK_SIZE;
let sub = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) {
let x_base = src1_idx_base + block * BLOCK_SIZE + sub * ELEMS_PER_THREAD;
var x_block: array<array<f32, ELEMS_PER_THREAD>, NUM_COLS>;
for (var col = 0u; col < NUM_COLS;col += 1) {
for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]);
x_block[col][i + 8] = f32(src1[x_base + col * params.stride_11 + i + 8]);
}
}
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = ue4m3_to_fp32(get_byte(load_u32_at_src0_aligned(block_byte_base), sub)) * 0.5;
let q_w0 = load_u32_at_src0_aligned(block_byte_base + 4u + 8u * sub);
let q_w1 = load_u32_at_src0_aligned(block_byte_base + 8u + 8u * sub);
for (var col = 0u;col < NUM_COLS;col += 1) {
var row_sum = 0.0;
for (var l = 0u; l < 8u; l++) {
let q_word = select(q_w0, q_w1, l >= 4u);
let q_byte = get_byte(q_word, l % 4u);
let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d;
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * d;
row_sum += q_lo * x_block[col][l];
row_sum += q_hi * x_block[col][l + 8u];
}
acc[col][row] += row_sum;
}
}
}
}
return acc;
}
#endif
+102 -2
View File
@@ -145,6 +145,7 @@ class Keys:
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval"
HASH_LAYER_COUNT = "{arch}.hash_layer_count"
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
@@ -180,8 +181,12 @@ class Keys:
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
OUTPUT_GROUP_COUNT = "{arch}.attention.output_group_count"
OUTPUT_LORA_RANK = "{arch}.attention.output_lora_rank"
OUTPUT_SCALE = "{arch}.attention.output_scale"
VALUE_SCALE = "{arch}.attention.value_scale"
COMPRESS_RATIOS = "{arch}.attention.compress_ratios"
COMPRESS_ROPE_FREQ_BASE = "{arch}.attention.compress_rope_freq_base"
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
@@ -196,6 +201,11 @@ class Keys:
KEY_LENGTH = "{arch}.attention.indexer.key_length"
TOP_K = "{arch}.attention.indexer.top_k"
class HyperConnection:
COUNT = "{arch}.hyper_connection.count"
SINKHORN_ITERATIONS = "{arch}.hyper_connection.sinkhorn_iterations"
EPSILON = "{arch}.hyper_connection.epsilon"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
@@ -470,6 +480,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
DEEPSEEK2OCR = auto()
DEEPSEEK32 = auto()
DEEPSEEK4 = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
@@ -555,6 +566,9 @@ class MODEL_TENSOR(IntEnum):
DENSE_2_OUT = auto() # embeddinggemma 2_Dense
DENSE_3_OUT = auto() # embeddinggemma 3_Dense
OUTPUT_NORM = auto()
HC_HEAD_FN = auto()
HC_HEAD_BASE = auto()
HC_HEAD_SCALE = auto()
ROPE_FREQS = auto()
ROPE_FACTORS_LONG = auto()
ROPE_FACTORS_SHORT = auto()
@@ -594,6 +608,7 @@ class MODEL_TENSOR(IntEnum):
FFN_DOWN_CHEXP = auto()
FFN_UP_CHEXP = auto()
FFN_EXP_PROBS_B = auto()
FFN_GATE_TID2EID = auto()
MOE_LATENT_DOWN = auto() # nemotron 3 super
MOE_LATENT_UP = auto() # nemotron 3 super
ATTN_Q_NORM = auto()
@@ -681,6 +696,20 @@ class MODEL_TENSOR(IntEnum):
ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
ATTN_KV = auto()
ATTN_KV_NORM = auto()
ATTN_OUT_A = auto()
ATTN_OUT_B = auto()
HC_ATTN_FN = auto()
HC_ATTN_BASE = auto()
HC_ATTN_SCALE = auto()
HC_FFN_FN = auto()
HC_FFN_BASE = auto()
HC_FFN_SCALE = auto()
ATTN_COMPRESSOR_WKV = auto()
ATTN_COMPRESSOR_WGATE = auto()
ATTN_COMPRESSOR_APE = auto()
ATTN_COMPRESSOR_NORM = auto()
FFN_SUB_NORM = auto()
ATTN_SUB_NORM = auto()
DEC_ATTN_NORM = auto()
@@ -742,6 +771,10 @@ class MODEL_TENSOR(IntEnum):
INDEXER_PROJ = auto()
INDEXER_ATTN_K = auto()
INDEXER_ATTN_Q_B = auto()
INDEXER_COMPRESSOR_WKV = auto()
INDEXER_COMPRESSOR_WGATE = auto()
INDEXER_COMPRESSOR_APE = auto()
INDEXER_COMPRESSOR_NORM = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@@ -1027,6 +1060,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr",
MODEL_ARCH.DEEPSEEK32: "deepseek32",
MODEL_ARCH.DEEPSEEK4: "deepseek4",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
@@ -1111,6 +1145,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense
MODEL_TENSOR.HC_HEAD_FN: "output_hc_fn",
MODEL_TENSOR.HC_HEAD_BASE: "output_hc_base",
MODEL_TENSOR.HC_HEAD_SCALE: "output_hc_scale",
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
@@ -1152,6 +1189,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.FFN_GATE_TID2EID: "blk.{bid}.ffn_gate_tid2eid",
MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
@@ -1237,6 +1275,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_KV: "blk.{bid}.attn_kv",
MODEL_TENSOR.ATTN_KV_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_OUT_A: "blk.{bid}.attn_output_a",
MODEL_TENSOR.ATTN_OUT_B: "blk.{bid}.attn_output_b",
MODEL_TENSOR.HC_ATTN_FN: "blk.{bid}.hc_attn_fn",
MODEL_TENSOR.HC_ATTN_BASE: "blk.{bid}.hc_attn_base",
MODEL_TENSOR.HC_ATTN_SCALE: "blk.{bid}.hc_attn_scale",
MODEL_TENSOR.HC_FFN_FN: "blk.{bid}.hc_ffn_fn",
MODEL_TENSOR.HC_FFN_BASE: "blk.{bid}.hc_ffn_base",
MODEL_TENSOR.HC_FFN_SCALE: "blk.{bid}.hc_ffn_scale",
MODEL_TENSOR.ATTN_COMPRESSOR_WKV: "blk.{bid}.attn_compressor_kv",
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE: "blk.{bid}.attn_compressor_gate",
MODEL_TENSOR.ATTN_COMPRESSOR_APE: "blk.{bid}.attn_compressor_ape",
MODEL_TENSOR.ATTN_COMPRESSOR_NORM: "blk.{bid}.attn_compressor_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
@@ -1298,6 +1350,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj",
MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k",
MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b",
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV: "blk.{bid}.indexer_compressor_kv",
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE: "blk.{bid}.indexer_compressor_gate",
MODEL_TENSOR.INDEXER_COMPRESSOR_APE: "blk.{bid}.indexer_compressor_ape",
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM: "blk.{bid}.indexer_compressor_norm",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -3138,6 +3194,49 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.DEEPSEEK4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.HC_HEAD_FN,
MODEL_TENSOR.HC_HEAD_BASE,
MODEL_TENSOR.HC_HEAD_SCALE,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_SINKS,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV,
MODEL_TENSOR.ATTN_KV_NORM,
MODEL_TENSOR.ATTN_OUT_A,
MODEL_TENSOR.ATTN_OUT_B,
MODEL_TENSOR.HC_ATTN_FN,
MODEL_TENSOR.HC_ATTN_BASE,
MODEL_TENSOR.HC_ATTN_SCALE,
MODEL_TENSOR.HC_FFN_FN,
MODEL_TENSOR.HC_FFN_BASE,
MODEL_TENSOR.HC_FFN_SCALE,
MODEL_TENSOR.ATTN_COMPRESSOR_WKV,
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE,
MODEL_TENSOR.ATTN_COMPRESSOR_APE,
MODEL_TENSOR.ATTN_COMPRESSOR_NORM,
MODEL_TENSOR.INDEXER_PROJ,
MODEL_TENSOR.INDEXER_ATTN_Q_B,
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV,
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE,
MODEL_TENSOR.INDEXER_COMPRESSOR_APE,
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_TID2EID,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.ERNIE4_5_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -4437,8 +4536,9 @@ class GGMLQuantizationType(IntEnum):
class ExpertGatingFuncType(IntEnum):
SOFTMAX = 1
SIGMOID = 2
SOFTMAX = 1
SIGMOID = 2
SQRTSOFTPLUS = 4
# TODO: add GGMLFileType from ggml_ftype in ggml.h
+24
View File
@@ -715,6 +715,9 @@ class GGUFWriter:
def add_full_attention_interval(self, interval: int) -> None:
self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)
def add_hash_layer_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.HASH_LAYER_COUNT.format(arch=self.arch), count)
def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
if isinstance(length, int):
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
@@ -952,6 +955,27 @@ class GGUFWriter:
def add_norm_before_residual(self, value: bool) -> None:
self.add_bool(Keys.LLM.NORM_BEFORE_RESIDUAL.format(arch=self.arch), value)
def add_attention_output_group_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.OUTPUT_GROUP_COUNT.format(arch=self.arch), count)
def add_attention_output_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.OUTPUT_LORA_RANK.format(arch=self.arch), length)
def add_attention_compress_ratios(self, values: Sequence[int]) -> None:
self.add_array(Keys.Attention.COMPRESS_RATIOS.format(arch=self.arch), values)
def add_attention_compress_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.Attention.COMPRESS_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_hyper_connection_count(self, count: int) -> None:
self.add_uint32(Keys.HyperConnection.COUNT.format(arch=self.arch), count)
def add_hyper_connection_sinkhorn_iterations(self, count: int) -> None:
self.add_uint32(Keys.HyperConnection.SINKHORN_ITERATIONS.format(arch=self.arch), count)
def add_hyper_connection_epsilon(self, value: float) -> None:
self.add_float32(Keys.HyperConnection.EPSILON.format(arch=self.arch), value)
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
@@ -0,0 +1,112 @@
{%- if not add_generation_prompt is defined -%}
{%- set add_generation_prompt = false -%}
{%- endif -%}
{%- if not thinking is defined -%}
{%- if enable_thinking is defined -%}
{%- set thinking = enable_thinking -%}
{%- else -%}
{%- set thinking = false -%}
{%- endif -%}
{%- endif -%}
{%- set dsml_token = 'DSML' -%}
{%- set thinking_start_token = '<think>' -%}
{%- set thinking_end_token = '</think>' -%}
{%- set tools_header = '## Tools\n\nYou have access to a set of tools to help answer the user\'s question. You can invoke tools by writing a "<' + dsml_token + 'tool_calls>" block like the following:\n\n<' + dsml_token + 'tool_calls>\n<' + dsml_token + 'invoke name="$TOOL_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$TOOL_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'tool_calls>\n\nString parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.\n\nIf thinking_mode is enabled (triggered by ' + thinking_start_token + '), you MUST output your complete reasoning inside ' + thinking_start_token + '...' + thinking_end_token + ' BEFORE any tool calls or final response.\n\nOtherwise, output directly after ' + thinking_end_token + ' with tool calls or final response.\n\n### Available Tool Schemas\n\n' -%}
{%- set tools_footer = '\nYou MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.\n' -%}
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- if ns.is_first_sp -%}
{%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
{%- set ns.is_first_sp = false -%}
{%- else -%}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if tools is defined and tools -%}
{%- set ts = namespace(schemas='') -%}
{%- for tool in tools -%}
{%- if tool['type'] == 'function' -%}
{%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
{%- endif -%}
{%- endfor -%}
{%- if ns.system_prompt -%}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
{%- else -%}
{%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
{%- endif -%}
{%- endif -%}
{{- bos_token -}}
{{- ns.system_prompt -}}
{%- set last_user_idx = namespace(value=-1) -%}
{%- for message in messages -%}
{%- if message['role'] == 'user' or message['role'] == 'developer' or message['role'] == 'tool' -%}
{%- set last_user_idx.value = loop.index0 -%}
{%- endif -%}
{%- endfor -%}
{%- set state = namespace(in_user=false) -%}
{%- for message in messages -%}
{%- if message['role'] == 'user' or message['role'] == 'developer' -%}
{%- if state.in_user -%}
{{- '\n\n' -}}
{%- else -%}
{{- '<User>' -}}
{%- set state.in_user = true -%}
{%- endif -%}
{{- message['content'] or '' -}}
{%- elif message['role'] == 'tool' -%}
{%- if state.in_user -%}
{{- '\n\n' -}}
{%- else -%}
{{- '<User>' -}}
{%- set state.in_user = true -%}
{%- endif -%}
{{- '<tool_result>' + (message['content'] or '') + '</tool_result>' -}}
{%- elif message['role'] == 'assistant' -%}
{%- set state.in_user = false -%}
{{- '<Assistant>' -}}
{%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
{%- if is_after_last_user and thinking -%}
{{- thinking_start_token -}}
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
{{- message['reasoning_content'] -}}
{%- endif -%}
{{- thinking_end_token -}}
{%- else -%}
{{- thinking_end_token -}}
{%- endif -%}
{%- if message['content'] is defined and message['content'] -%}
{{- message['content'] -}}
{%- endif -%}
{%- if message['tool_calls'] -%}
{{- '\n\n<' + dsml_token + 'tool_calls>\n' -}}
{%- for tool in message['tool_calls'] -%}
{%- set func = tool['function'] -%}
{{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
{%- set args = func['arguments'] -%}
{%- if args is string -%}
{%- set args = args | from_json -%}
{%- endif -%}
{%- for key, val in args.items() -%}
{%- if val is string -%}
{{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
{%- else -%}
{{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
{%- endif -%}
{%- endfor -%}
{{- '</' + dsml_token + 'invoke>\n' -}}
{%- endfor -%}
{{- '</' + dsml_token + 'tool_calls>' -}}
{%- endif -%}
{{- '<end▁of▁sentence>' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- '<Assistant>' -}}
{%- if thinking -%}
{{- thinking_start_token -}}
{%- else -%}
{{- thinking_end_token -}}
{%- endif -%}
{%- endif -%}
+4 -1
View File
@@ -69,13 +69,16 @@ mbuf=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel $fasel \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 1024 -fa on \
+4 -1
View File
@@ -57,6 +57,9 @@ opfuse=
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
fasel=
[ "$FA" != "" ] && fasel="GGML_HEXAGON_FA_SELECT=$FA"
set -x
tool=$1; shift
@@ -65,5 +68,5 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel $fasel ./$branch/bin/$tool $@ \
"
@@ -230,6 +230,12 @@ def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=Non
char = 'Q'
elif norm_evt == 'A-PREP':
char = 'A'
elif norm_evt == 'Q-PREP':
char = 'q'
elif norm_evt == 'K-PREP':
char = 'k'
elif norm_evt == 'V-PREP':
char = 'v'
elif norm_evt == 'W-DEQUANT':
char = 'D'
elif norm_evt == 'O-PROC':
+1
View File
@@ -25,6 +25,7 @@ add_library(llama
llama-kv-cache.cpp
llama-kv-cache-iswa.cpp
llama-kv-cache-dsa.cpp
llama-kv-cache-dsv4.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-hybrid-iswa.cpp
+56
View File
@@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" },
{ LLM_ARCH_DEEPSEEK32, "deepseek32" },
{ LLM_ARCH_DEEPSEEK4, "deepseek4" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
@@ -250,9 +251,19 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT, "%s.attention.output_group_count" },
{ LLM_KV_ATTENTION_OUTPUT_LORA_RANK, "%s.attention.output_lora_rank" },
{ LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE, "%s.attention.compress_rope_freq_base" },
{ LLM_KV_ATTENTION_COMPRESS_RATIOS, "%s.attention.compress_ratios" },
{ LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" },
{ LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" },
{ LLM_KV_HYPER_CONNECTION_COUNT, "%s.hyper_connection.count" },
{ LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS, "%s.hyper_connection.sinkhorn_iterations" },
{ LLM_KV_HYPER_CONNECTION_EPSILON, "%s.hyper_connection.epsilon" },
{ LLM_KV_HASH_LAYER_COUNT, "%s.hash_layer_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -440,6 +451,23 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_KV, "blk.%d.attn_kv" },
{ LLM_TENSOR_ATTN_KV_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_OUT_A, "blk.%d.attn_output_a" },
{ LLM_TENSOR_ATTN_OUT_B, "blk.%d.attn_output_b" },
{ LLM_TENSOR_HC_HEAD_FN, "output_hc_fn" },
{ LLM_TENSOR_HC_HEAD_BASE, "output_hc_base" },
{ LLM_TENSOR_HC_HEAD_SCALE, "output_hc_scale" },
{ LLM_TENSOR_HC_ATTN_FN, "blk.%d.hc_attn_fn" },
{ LLM_TENSOR_HC_ATTN_BASE, "blk.%d.hc_attn_base" },
{ LLM_TENSOR_HC_ATTN_SCALE, "blk.%d.hc_attn_scale" },
{ LLM_TENSOR_HC_FFN_FN, "blk.%d.hc_ffn_fn" },
{ LLM_TENSOR_HC_FFN_BASE, "blk.%d.hc_ffn_base" },
{ LLM_TENSOR_HC_FFN_SCALE, "blk.%d.hc_ffn_scale" },
{ LLM_TENSOR_ATTN_COMPRESSOR_WKV, "blk.%d.attn_compressor_kv" },
{ LLM_TENSOR_ATTN_COMPRESSOR_WGATE, "blk.%d.attn_compressor_gate" },
{ LLM_TENSOR_ATTN_COMPRESSOR_APE, "blk.%d.attn_compressor_ape" },
{ LLM_TENSOR_ATTN_COMPRESSOR_NORM, "blk.%d.attn_compressor_norm" },
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
@@ -566,6 +594,11 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_WKV, "blk.%d.indexer_compressor_kv" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, "blk.%d.indexer_compressor_gate" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_APE, "blk.%d.indexer_compressor_ape" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_NORM, "blk.%d.indexer_compressor_norm" },
{ LLM_TENSOR_FFN_GATE_TID2EID, "blk.%d.ffn_gate_tid2eid" },
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
{ LLM_TENSOR_FC, "fc" },
@@ -616,6 +649,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_OUT_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_HEAD_FN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_HEAD_BASE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_ADD}},
{LLM_TENSOR_HC_HEAD_SCALE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_HC_ATTN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_ATTN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_HC_ATTN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_HC_FFN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_FFN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_HC_FFN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_ATTN_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
@@ -779,6 +829,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_INDEXER_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_GATE_TID2EID, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
@@ -933,6 +988,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
case LLM_ARCH_OLMOE:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK32:
case LLM_ARCH_DEEPSEEK4:
case LLM_ARCH_GLM_DSA:
case LLM_ARCH_BITNET:
case LLM_ARCH_T5:
+33
View File
@@ -82,6 +82,7 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_DEEPSEEK2OCR,
LLM_ARCH_DEEPSEEK32,
LLM_ARCH_DEEPSEEK4,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
@@ -255,9 +256,19 @@ enum llm_kv {
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT,
LLM_KV_ATTENTION_OUTPUT_LORA_RANK,
LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE,
LLM_KV_ATTENTION_COMPRESS_RATIOS,
LLM_KV_ATTENTION_SHARED_KV_LAYERS,
LLM_KV_ATTENTION_RECURRENT_LAYERS,
LLM_KV_HYPER_CONNECTION_COUNT,
LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS,
LLM_KV_HYPER_CONNECTION_EPSILON,
LLM_KV_HASH_LAYER_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -501,10 +512,27 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_KV,
LLM_TENSOR_ATTN_KV_NORM,
LLM_TENSOR_ATTN_OUT_A,
LLM_TENSOR_ATTN_OUT_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_HC_HEAD_FN,
LLM_TENSOR_HC_HEAD_BASE,
LLM_TENSOR_HC_HEAD_SCALE,
LLM_TENSOR_HC_ATTN_FN,
LLM_TENSOR_HC_ATTN_BASE,
LLM_TENSOR_HC_ATTN_SCALE,
LLM_TENSOR_HC_FFN_FN,
LLM_TENSOR_HC_FFN_BASE,
LLM_TENSOR_HC_FFN_SCALE,
LLM_TENSOR_ATTN_COMPRESSOR_WKV,
LLM_TENSOR_ATTN_COMPRESSOR_WGATE,
LLM_TENSOR_ATTN_COMPRESSOR_APE,
LLM_TENSOR_ATTN_COMPRESSOR_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
LLM_TENSOR_FFN_SUB_NORM,
LLM_TENSOR_DEC_ATTN_NORM,
@@ -566,6 +594,11 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_INDEXER_COMPRESSOR_WKV,
LLM_TENSOR_INDEXER_COMPRESSOR_WGATE,
LLM_TENSOR_INDEXER_COMPRESSOR_APE,
LLM_TENSOR_INDEXER_COMPRESSOR_NORM,
LLM_TENSOR_FFN_GATE_TID2EID,
LLM_TENSOR_NEXTN_PROJ_PRE,
LLM_TENSOR_NEXTN_PROJ_POST,
LLM_TENSOR_NEXTN_EH_PROJ,
+5 -1
View File
@@ -2321,7 +2321,11 @@ void llama_context::output_reorder() {
//
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
if (model.arch == LLM_ARCH_QWEN3NEXT ||
model.arch == LLM_ARCH_KIMI_LINEAR ||
model.arch == LLM_ARCH_QWEN35 ||
model.arch == LLM_ARCH_QWEN35MOE ||
model.arch == LLM_ARCH_DEEPSEEK4) {
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
}
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
+352 -23
View File
@@ -8,6 +8,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-kv-cache-dsa.h"
#include "llama-kv-cache-dsv4.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
@@ -17,6 +18,7 @@
#include <cstring>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_set>
// dedup helpers
@@ -568,7 +570,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
// base tensors may not be allocated if there are no non-SWA attention layers
if (self_k_idxs && self_k_idxs->buffer) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
if (self_v_idxs) {
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
}
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
@@ -579,7 +583,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
if (self_v_idxs_swa) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
}
}
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
@@ -633,6 +639,283 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
return res;
}
static void dsv4_set_i64(ggml_tensor * dst, const std::vector<int64_t> & src) {
if (!dst || !dst->buffer) {
return;
}
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
}
static void dsv4_set_i32(ggml_tensor * dst, const std::vector<int32_t> & src) {
if (!dst || !dst->buffer) {
return;
}
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
}
static void dsv4_set_kq_mask(
ggml_tensor * dst,
const llama_kv_cache_dsv4_context::comp_plan & plan,
uint32_t n_tokens,
int64_t n_stream) {
if (!dst || !dst->buffer) {
return;
}
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(n_stream > 0);
GGML_ASSERT(n_tokens%n_stream == 0);
GGML_ASSERT(dst->ne[0] == plan.n_kv);
GGML_ASSERT(dst->ne[1] == (int64_t) n_tokens/n_stream);
GGML_ASSERT(dst->ne[2] == 1);
GGML_ASSERT(dst->ne[3] == n_stream);
GGML_ASSERT((int64_t) plan.n_visible.size() == (int64_t) n_tokens);
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;
for (int64_t i = 0; i < (int64_t) n_tokens; ++i) {
const int32_t n_visible = plan.n_visible[i];
for (int64_t j = 0; j < dst->ne[0]; ++j) {
data[i*dst->ne[0] + j] = j < n_visible ? 0.0f : -INFINITY;
}
}
}
static ggml_tensor * dsv4_build_raw_kq_mask(
ggml_context * ctx,
const llama_kv_cache_dsv4_raw_context * mctx,
const llama_ubatch & ubatch,
const llama_cparams & cparams,
int64_t n_stream) {
const auto n_kv = mctx->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
GGML_ASSERT(n_stream > 0);
GGML_ASSERT(n_tokens%n_stream == 0);
const bool use_fattn = cparams.flash_attn && (!cparams.kv_unified || n_stream == 1);
const auto type = use_fattn ? GGML_TYPE_F16 : GGML_TYPE_F32;
ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");
return res;
}
static bool dsv4_can_reuse_raw_kq_mask(
ggml_tensor * kq_mask,
const llama_kv_cache_dsv4_raw_context * mctx,
const llama_ubatch & ubatch,
int64_t n_stream) {
const auto n_kv = mctx->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
GGML_ASSERT(n_stream > 0);
bool res = true;
res &= (kq_mask->ne[0] == n_kv);
res &= (kq_mask->ne[1] == n_tokens/n_stream);
res &= (kq_mask->ne[2] == 1);
res &= (kq_mask->ne[3] == n_stream);
return res;
}
static std::string dsv4_plan_positions(const std::vector<int32_t> & values) {
std::ostringstream ss;
ss << "[";
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
ss << ", ";
}
ss << values[i];
}
ss << "]";
return ss.str();
}
static bool dsv4_compress_debug() {
static const bool debug = []() {
const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG");
return env && atoi(env) > 0;
}();
return debug;
}
static void dsv4_set_comp_inputs(
const llm_graph_input_dsv4::comp_input & inp,
const llama_kv_cache_dsv4_context::comp_plan & plan,
const char * name,
bool debug,
uint32_t n_tokens,
int64_t n_stream) {
dsv4_set_i32(inp.state_pos, plan.state_pos);
dsv4_set_i32(inp.state_persist_src_idxs, plan.state_persist_src_idxs);
dsv4_set_i32(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs);
dsv4_set_i32(inp.state_read_idxs, plan.state_read_idxs);
dsv4_set_i64(inp.state_write_idxs, plan.state_write_idxs);
dsv4_set_i32(inp.state_write_pos, plan.state_write_pos);
dsv4_set_kq_mask(inp.kq_mask, plan, n_tokens, n_stream);
if (debug || dsv4_compress_debug()) {
LLAMA_LOG_INFO("%s: %s n_tokens=%u, n_stream=%d, state_persist_dst=%s, state_write_pos=%s\n",
__func__, name, n_tokens, (int) n_stream,
dsv4_plan_positions(plan.state_persist_dst_idxs).c_str(),
dsv4_plan_positions(plan.state_write_pos).c_str());
}
}
static bool dsv4_can_reuse_tensor_1d(ggml_tensor * t, int64_t ne0) {
return (t == nullptr && ne0 == 0) || (t != nullptr && t->ne[0] == ne0);
}
static bool dsv4_can_reuse_kq_mask(
ggml_tensor * t,
const llama_kv_cache_dsv4_context::comp_plan & plan,
uint32_t n_tokens,
int64_t n_stream) {
if (plan.n_kv == 0) {
return t == nullptr;
}
GGML_ASSERT(n_stream > 0);
return t != nullptr &&
t->ne[0] == plan.n_kv &&
t->ne[1] == (int64_t) n_tokens/n_stream &&
t->ne[2] == 1 &&
t->ne[3] == n_stream;
}
static bool dsv4_can_reuse_comp_input(
const llm_graph_input_dsv4::comp_input & inp,
const llama_kv_cache_dsv4_context::comp_plan & plan,
uint32_t n_tokens,
int64_t n_stream) {
bool res = true;
res &= dsv4_can_reuse_tensor_1d(inp.state_pos, plan.state_pos.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_src_idxs, plan.state_persist_src_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_read_idxs, plan.state_read_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_write_idxs, plan.state_write_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_write_pos, plan.state_write_pos.size());
res &= dsv4_can_reuse_kq_mask(inp.kq_mask, plan, n_tokens, n_stream);
return res;
}
static ggml_tensor * dsv4_build_input_1d(
ggml_context * ctx,
ggml_type type,
int64_t ne0,
const std::string & name) {
if (ne0 == 0) {
return nullptr;
}
ggml_tensor * res = ggml_new_tensor_1d(ctx, type, ne0);
ggml_set_input(res);
ggml_set_name(res, name.c_str());
return res;
}
static void dsv4_build_comp_inputs(
ggml_context * ctx,
llm_graph_input_dsv4::comp_input & inp,
const llama_kv_cache_dsv4_context::comp_plan & plan,
const char * name,
int64_t n_stream) {
inp.state_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_pos.size(), std::string("dsv4_") + name + "_state_pos");
inp.state_persist_src_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_src_idxs.size(), std::string("dsv4_") + name + "_state_persist_src_idxs");
inp.state_persist_dst_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_dst_idxs.size(), std::string("dsv4_") + name + "_state_persist_dst_idxs");
inp.state_read_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_read_idxs.size(), std::string("dsv4_") + name + "_state_read_idxs");
inp.state_write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, plan.state_write_idxs.size(), std::string("dsv4_") + name + "_state_write_idxs");
inp.state_write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_pos.size(), std::string("dsv4_") + name + "_state_write_pos");
if (plan.n_kv > 0) {
const int64_t n_tokens = (int64_t) plan.n_visible.size();
GGML_ASSERT(n_stream > 0);
GGML_ASSERT(n_tokens%n_stream == 0);
inp.kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, plan.n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp.kq_mask);
ggml_set_name(inp.kq_mask, (std::string("dsv4_") + name + "_kq_mask").c_str());
}
}
void llm_graph_input_dsv4_raw::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs && self_k_idxs->buffer) {
mctx->set_input_k_idxs(self_k_idxs);
}
if (self_kq_mask && self_kq_mask->buffer) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
}
}
void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) {
const auto & plan_csa = mctx->get_csa_plan(*ubatch);
const auto & plan_hca = mctx->get_hca_plan(*ubatch);
const auto & plan_lid = mctx->get_lid_plan(*ubatch);
const int64_t n_stream = plan_csa.n_stream;
inp_raw->mctx = mctx->get_raw();
inp_raw->set_input(ubatch);
dsv4_set_comp_inputs(inp_csa, plan_csa, "csa", debug > 0, ubatch->n_tokens, n_stream);
dsv4_set_comp_inputs(inp_hca, plan_hca, "hca", debug > 0, ubatch->n_tokens, n_stream);
dsv4_set_comp_inputs(inp_lid, plan_lid, "lid", debug > 0, ubatch->n_tokens, n_stream);
if (inp_lid.k_rot && inp_lid.k_rot->buffer) {
mctx->get_lid()->set_input_k_rot(inp_lid.k_rot);
}
}
bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_dsv4_context *>(params.mctx);
this->mctx = mctx;
inp_raw->mctx = mctx->get_raw();
bool res = true;
const auto & plan_csa = mctx->get_csa_plan(params.ubatch);
const auto & plan_hca = mctx->get_hca_plan(params.ubatch);
const auto & plan_lid = mctx->get_lid_plan(params.ubatch);
const int64_t n_stream = plan_csa.n_stream;
const auto * raw_ctx = mctx->get_raw();
inp_raw->mctx = raw_ctx;
if (inp_raw->self_k_idxs && inp_raw->self_k_idxs->buffer) {
res &= inp_raw->self_k_idxs->ne[0] == raw_ctx->get_n_write();
}
if (inp_raw->self_kq_mask && inp_raw->self_kq_mask->buffer) {
res &= dsv4_can_reuse_raw_kq_mask(inp_raw->self_kq_mask, raw_ctx, params.ubatch, n_stream);
}
res &= dsv4_can_reuse_comp_input(inp_csa, plan_csa, params.ubatch.n_tokens, n_stream);
res &= dsv4_can_reuse_comp_input(inp_hca, plan_hca, params.ubatch.n_tokens, n_stream);
res &= dsv4_can_reuse_comp_input(inp_lid, plan_lid, params.ubatch.n_tokens, n_stream);
return res;
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(cross_kq_mask);
@@ -1351,20 +1634,24 @@ ggml_tensor * llm_graph_context::build_ffn(
switch (type_op) {
case LLM_FFN_SILU:
if (gate && type_gate == LLM_FFN_PAR) {
// Step35: HF clamps gate (after SiLU) and up before multiplication
if (arch == LLM_ARCH_STEP35 && il >= 0) {
if (il >= 0) {
const float limit = hparams.swiglu_clamp_shexp[il];
constexpr float eps = 1e-6f;
if (limit > eps) {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_silu_clamped", il);
tmp = ggml_clamp(ctx0, tmp, -limit, limit);
cb(tmp, "ffn_up_clamped", il);
cur = ggml_mul(ctx0, gate_act, tmp);
if (arch == LLM_ARCH_DEEPSEEK4) {
cur = ggml_clamp(ctx0, cur, -INFINITY, limit);
cb(cur, "ffn_gate_clamped", il);
cur = ggml_swiglu_split(ctx0, cur, tmp);
} else {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_silu_clamped", il);
cur = ggml_mul(ctx0, gate_act, tmp);
}
cb(cur, "ffn_swiglu_limited", il);
type_gate = LLM_FFN_SEQ;
break;
@@ -1474,7 +1761,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * gate_up_exps,
ggml_tensor * up_exps_s,
ggml_tensor * gate_exps_s,
ggml_tensor * down_exps_s) const {
ggml_tensor * down_exps_s,
ggml_tensor * selected_experts_in) const {
return build_moe_ffn(
cur,
gate_inp, /* gate_inp_b */ nullptr,
@@ -1494,7 +1782,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
/* gate_up_exps_b */ nullptr,
up_exps_s,
gate_exps_s,
down_exps_s
down_exps_s,
selected_experts_in
);
}
@@ -1521,7 +1810,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * gate_up_exps_b,
ggml_tensor * up_exps_s,
ggml_tensor * gate_exps_s,
ggml_tensor * down_exps_s) const {
ggml_tensor * down_exps_s,
ggml_tensor * selected_experts_in) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1530,6 +1820,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
if (probs_in == nullptr) {
logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS) {
ggml_mul_mat_set_prec(logits, GGML_PREC_F32);
}
cb(logits, "ffn_moe_logits", il);
} else {
logits = probs_in;
@@ -1554,6 +1847,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
{
probs = logits; // [n_expert, n_tokens]
} break;
case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS:
{
probs = ggml_sqrt(ctx0, ggml_softplus(ctx0, logits)); // [n_expert, n_tokens]
} break;
default:
GGML_ABORT("fatal error");
}
@@ -1604,8 +1901,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
}
// select experts
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
ggml_tensor * selected_experts = selected_experts_in;
if (selected_experts == nullptr) {
selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
}
cb(selected_experts, "ffn_moe_topk", il);
if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
@@ -1718,20 +2018,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
switch (type_op) {
case LLM_FFN_SILU:
if (gate_exps) {
// Step35: per-layer clamp for routed experts
if (arch == LLM_ARCH_STEP35 && il >= 0) {
if (il >= 0) {
const float limit = hparams.swiglu_clamp_exp[il];
constexpr float eps = 1e-6f;
if (limit > eps) {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_moe_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_moe_silu_clamped", il);
up = ggml_clamp(ctx0, up, -limit, limit);
cb(up, "ffn_moe_up_clamped", il);
cur = ggml_mul(ctx0, gate_act, up);
if (arch == LLM_ARCH_DEEPSEEK4) {
cur = ggml_clamp(ctx0, cur, -INFINITY, limit);
cb(cur, "ffn_moe_gate_clamped", il);
cur = ggml_swiglu_split(ctx0, cur, up);
} else {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_moe_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_moe_silu_clamped", il);
cur = ggml_mul(ctx0, gate_act, up);
}
cb(cur, "ffn_moe_swiglu_limited", il);
break;
}
@@ -2760,6 +3064,31 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_dsv4_context *>(mctx);
const auto * raw_ctx = mctx_cur->get_raw();
auto inp_raw = std::make_unique<llm_graph_input_dsv4_raw>(cparams, raw_ctx);
const int64_t n_stream = mctx_cur->get_csa_plan(ubatch).n_stream;
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache");
inp_raw->self_k_idxs = raw_ctx->build_input_k_idxs(ctx0, ubatch);
inp_raw->self_kq_mask = dsv4_build_raw_kq_mask(ctx0, raw_ctx, ubatch, cparams, n_stream);
inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask;
inp_raw->self_k_rot = raw_ctx->build_input_k_rot(ctx0);
auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move(inp_raw), mctx_cur);
dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(ubatch), "csa", n_stream);
dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(ubatch), "hca", n_stream);
dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(ubatch), "lid", n_stream);
inp->inp_lid.k_rot = mctx_cur->get_lid()->build_input_k_rot(ctx0);
return (llm_graph_input_dsv4 *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * s,
ggml_tensor * state_copy_main,
+81 -2
View File
@@ -23,6 +23,8 @@ struct llama_memory_context_i;
class llama_kv_cache_context;
class llama_kv_cache_dsa_context;
class llama_kv_cache_dsv4_raw_context;
class llama_kv_cache_dsv4_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
@@ -459,6 +461,79 @@ public:
const llama_kv_cache_iswa_context * mctx;
};
// DSV4 raw graph inputs are SWA-only, but their mask may be stream-shaped
// so raw K can be concatenated with DSV4 compressed K in one attention op.
class llm_graph_input_dsv4_raw {
public:
llm_graph_input_dsv4_raw(
const llama_cparams & cparams,
const llama_kv_cache_dsv4_raw_context * mctx) :
cparams(cparams),
mctx(mctx) {
}
void set_input(const llama_ubatch * ubatch);
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_k_rot = nullptr;
const llama_cparams cparams;
const llama_kv_cache_dsv4_raw_context * mctx;
};
class llm_graph_input_dsv4 : public llm_graph_input_i {
public:
struct comp_input {
ggml_tensor * state_pos = nullptr; // I32 [n_state]
ggml_tensor * state_persist_src_idxs = nullptr; // I32 [n_state_persist]
ggml_tensor * state_persist_dst_idxs = nullptr; // I32 [n_state_persist]
ggml_tensor * state_read_idxs = nullptr; // I32 [ratio*n_state_write]
ggml_tensor * state_write_idxs = nullptr; // I64 [n_state_write]
ggml_tensor * state_write_pos = nullptr; // I32 [n_state_write]
ggml_tensor * kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * k_rot = nullptr;
};
llm_graph_input_dsv4(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw,
const llama_kv_cache_dsv4_context * mctx) :
inp_raw(std::move(inp_raw)),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_dsv4() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
llm_graph_input_dsv4_raw * get_raw() const { return inp_raw.get(); }
const comp_input & get_csa() const { return inp_csa; }
const comp_input & get_hca() const { return inp_hca; }
const comp_input & get_lid() const { return inp_lid; }
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw;
comp_input inp_csa;
comp_input inp_hca;
comp_input inp_lid;
const llama_cparams cparams;
const llama_kv_cache_dsv4_context * mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -920,7 +995,8 @@ struct llm_graph_context {
ggml_tensor * gate_up_exps = nullptr,
ggml_tensor * up_exps_s = nullptr,
ggml_tensor * gate_exps_s = nullptr,
ggml_tensor * down_exps_s = nullptr) const;
ggml_tensor * down_exps_s = nullptr,
ggml_tensor * selected_experts_in = nullptr) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
@@ -945,7 +1021,8 @@ struct llm_graph_context {
ggml_tensor * gate_up_exps_b = nullptr,
ggml_tensor * up_exps_s = nullptr,
ggml_tensor * gate_exps_s = nullptr,
ggml_tensor * down_exps_s = nullptr) const;
ggml_tensor * down_exps_s = nullptr,
ggml_tensor * selected_experts_in = nullptr) const;
//
// inputs
@@ -1045,6 +1122,8 @@ struct llm_graph_context {
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
llm_graph_input_dsv4 * build_inp_dsv4() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_iswa * inp,
+11
View File
@@ -14,6 +14,7 @@ enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits
LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS = 4,
};
enum llama_swa_type {
@@ -226,6 +227,16 @@ struct llama_hparams {
uint32_t indexer_head_size = 0;
uint32_t indexer_top_k = 0;
// DeepSeek-V4
uint32_t dsv4_o_group_count = 0;
uint32_t dsv4_o_lora_rank = 0;
uint32_t dsv4_hc_mult = 0;
uint32_t dsv4_hc_sinkhorn_iters = 0;
uint32_t dsv4_hash_layer_count = 0;
float dsv4_compress_rope_base = 0.0f;
float dsv4_hc_eps = 0.0f;
std::array<uint32_t, LLAMA_MAX_LAYERS> dsv4_compress_ratios;
// qwen3vl deepstack
// When parsed from GGUF, this implies the first N layers consume the first
// N deepstack embeddings. Use deepstack_mapping_arr if you need a more
File diff suppressed because it is too large Load Diff
+362
View File
@@ -0,0 +1,362 @@
#pragma once
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>
class llama_dsv4_comp_state {
public:
llama_dsv4_comp_state(
const llama_model & model,
bool offload,
bool unified,
uint32_t n_seq_max,
uint32_t ratio,
uint32_t state_size,
uint32_t n_embd_state,
const char * name,
const llama_memory_i::layer_filter_cb & filter);
void clear(bool data);
uint32_t get_ratio() const;
uint32_t get_state_size() const;
uint32_t get_n_stream() const;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const;
void state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
ggml_tensor * get_kv (ggml_context * ctx, int32_t il) const;
ggml_tensor * get_score(ggml_context * ctx, int32_t il) const;
ggml_tensor * cpy_kv (ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const;
ggml_tensor * cpy_score(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const;
private:
struct layer {
uint32_t il;
ggml_tensor * kv;
ggml_tensor * score;
};
const uint32_t ratio;
const uint32_t state_size;
const uint32_t n_embd_state;
const uint32_t n_stream;
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
std::vector<layer> layers;
std::unordered_map<int32_t, int32_t> map_layer_ids;
size_t total_size() const;
};
//
// llama_kv_cache_dsv4
//
// DSV4 uses a normal raw/SWA token cache plus compressed K-only block caches.
// The compressed caches are storage only; DSV4-specific visibility and block
// planning are handled by llama_kv_cache_dsv4_context / llm_graph_input_dsv4.
class llama_kv_cache_dsv4 : public llama_memory_i {
public:
llama_kv_cache_dsv4(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_dsv4() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_dsv4 specific API
//
llama_kv_cache_iswa * get_raw() const;
llama_kv_cache * get_csa() const;
llama_kv_cache * get_hca() const;
llama_kv_cache * get_lid() const;
llama_dsv4_comp_state * get_csa_state() const;
llama_dsv4_comp_state * get_hca_state() const;
llama_dsv4_comp_state * get_lid_state() const;
private:
llama_hparams hparams_raw;
llama_hparams hparams_csa;
llama_hparams hparams_hca;
llama_hparams hparams_lid;
const uint32_t n_seq_max;
std::unique_ptr<llama_kv_cache_iswa> kv_raw;
std::unique_ptr<llama_kv_cache> kv_csa;
std::unique_ptr<llama_kv_cache> kv_hca;
std::unique_ptr<llama_kv_cache> kv_lid;
std::unique_ptr<llama_dsv4_comp_state> csa_state;
std::unique_ptr<llama_dsv4_comp_state> hca_state;
std::unique_ptr<llama_dsv4_comp_state> lid_state;
void clear_compressed(bool data);
};
// DSV4 raw attention only uses the SWA half of kv_raw. The base half is kept
// for generic ISWA bookkeeping, but it has no DSV4 layers to expose here.
class llama_kv_cache_dsv4_raw_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
llama_kv_cache_dsv4_raw_context(llama_kv_cache_iswa * kv);
llama_kv_cache_dsv4_raw_context(
llama_kv_cache_iswa * kv,
llama_context * lctx,
bool optimize);
llama_kv_cache_dsv4_raw_context(
llama_kv_cache_iswa * kv,
slot_info_vec_t sinfos_base_write,
slot_info_vec_t sinfos_swa_write,
slot_info_vec_t sinfos_swa_read,
std::vector<llama_ubatch> ubatches,
std::vector<llama_ubatch> ubatches_write);
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
uint32_t get_n_kv() const;
uint32_t get_n_write() const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
void set_input_k_idxs(ggml_tensor * dst) const;
void set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_rot(ggml_tensor * dst) const;
private:
size_t i_next = 0;
llama_kv_cache * kv_swa = nullptr;
slot_info_vec_t sinfos_write;
slot_info_vec_t sinfos_read;
std::vector<llama_ubatch> ubatches;
std::vector<llama_ubatch> ubatches_write;
const llama_memory_context_ptr ctx_base_mem;
const llama_memory_context_ptr ctx_swa_mem;
uint32_t n_kv = 0;
const llama_memory_status status;
};
// DSV4 compressed KV rows are graph outputs, not normal token KV writes.
// Keep a small context that exposes K tensors without generic apply() semantics.
class llama_kv_cache_dsv4_comp_context {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
llama_kv_cache_dsv4_comp_context(llama_kv_cache * kv);
llama_kv_cache_dsv4_comp_context(
llama_kv_cache * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);
bool next();
uint32_t get_n_kv() const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
void set_input_k_rot(ggml_tensor * dst) const;
private:
llama_kv_cache * kv;
size_t i_cur = 0;
slot_info_vec_t sinfos;
std::vector<llama_ubatch> ubatches;
uint32_t n_kv;
};
class llama_kv_cache_dsv4_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
struct comp_plan {
// Per-ubatch recipe for updating compressor state, committing completed
// compressed rows, and masking the compressed attention source.
// APE row ids, i.e. pos % ratio, for the compressor-state updates.
std::vector<int32_t> state_pos;
// Current-ubatch source row ids and unique persistent-state
// destination row ids for deterministic ring-state updates.
std::vector<int32_t> state_persist_src_idxs;
std::vector<int32_t> state_persist_dst_idxs;
// Flattened source row ids used for state-backed commits. Source rows
// index the graph-local [persistent_state | current_ubatch_scratch]
// tensor. For overlapped compression the first half is previous rows
// and the second half is current rows; a final synthetic zero/-inf row
// may be addressed for the first block's previous half.
std::vector<int32_t> state_read_idxs;
// Final compressed-cache row ids written by state-backed commits.
// A non-boundary CSA/LID decode step can target a masked scratch row.
std::vector<int64_t> state_write_idxs;
// RoPE positions for state-backed commits.
std::vector<int32_t> state_write_pos;
// Number of completed compressed rows visible for each query token.
std::vector<int32_t> n_visible;
// Number of streams used by the attention graph for this ubatch.
int64_t n_stream = 1;
// Graph-width for compressed rows. This can be larger than n_visible
// so masked padding rows do not force a new graph at every CSA block.
int64_t n_kv = 0;
};
llama_kv_cache_dsv4_context(llama_memory_status status);
llama_kv_cache_dsv4_context(
llama_kv_cache_dsv4 * kv);
llama_kv_cache_dsv4_context(
llama_kv_cache_dsv4 * kv,
llama_context * lctx,
bool optimize);
llama_kv_cache_dsv4_context(
llama_kv_cache_dsv4 * kv,
slot_info_vec_t sinfos_raw_base_write,
slot_info_vec_t sinfos_raw_swa_write,
slot_info_vec_t sinfos_raw_swa_read,
std::vector<llama_ubatch> ubatches,
std::vector<llama_ubatch> ubatches_raw);
virtual ~llama_kv_cache_dsv4_context();
//
// llama_memory_context_i
//
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_dsv4_context specific API
//
const llama_kv_cache_dsv4_raw_context * get_raw() const;
const llama_kv_cache_dsv4_comp_context * get_csa() const;
const llama_kv_cache_dsv4_comp_context * get_hca() const;
const llama_kv_cache_dsv4_comp_context * get_lid() const;
const llama_dsv4_comp_state * get_csa_state() const;
const llama_dsv4_comp_state * get_hca_state() const;
const llama_dsv4_comp_state * get_lid_state() const;
const comp_plan & get_csa_plan() const;
const comp_plan & get_hca_plan() const;
const comp_plan & get_lid_plan() const;
const comp_plan & get_csa_plan(const llama_ubatch & ubatch) const;
const comp_plan & get_hca_plan(const llama_ubatch & ubatch) const;
const comp_plan & get_lid_plan(const llama_ubatch & ubatch) const;
private:
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
std::vector<comp_plan> plans_csa;
std::vector<comp_plan> plans_hca;
std::vector<comp_plan> plans_lid;
const std::unique_ptr<llama_kv_cache_dsv4_raw_context> ctx_raw;
const llama_memory_context_ptr ctx_csa_mem;
const llama_memory_context_ptr ctx_hca_mem;
const llama_memory_context_ptr ctx_lid_mem;
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_csa;
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_hca;
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_lid;
const llama_dsv4_comp_state * csa_state = nullptr;
const llama_dsv4_comp_state * hca_state = nullptr;
const llama_dsv4_comp_state * lid_state = nullptr;
bool reserve_plans = false;
mutable comp_plan reserve_plan_csa;
mutable comp_plan reserve_plan_hca;
mutable comp_plan reserve_plan_lid;
const llama_memory_status status;
};
+22 -1
View File
@@ -26,7 +26,28 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
const layer_share_cb & share) :
llama_kv_cache_iswa(model, model.hparams, type_k, type_v, v_trans, offload, swa_full, unified,
kv_size, n_seq_max, n_ubatch, n_pad, mem_other, filter, reuse, share) {
}
llama_kv_cache_iswa::llama_kv_cache_iswa(
const llama_model & model,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share) : unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
+18 -2
View File
@@ -30,6 +30,24 @@ public:
const layer_reuse_cb & reuse,
const layer_share_cb & share);
llama_kv_cache_iswa(
const llama_model & model,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache_iswa() = default;
//
@@ -73,8 +91,6 @@ public:
llama_kv_cache * get_swa () const;
private:
const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache> kv_base;
+26 -6
View File
@@ -211,10 +211,12 @@ llama_kv_cache::llama_kv_cache(
n_embd_head_k_all = -1;
}
if (n_embd_head_v_all == 0) {
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
n_embd_head_v_all = -1;
if (!is_mla) {
if (n_embd_head_v_all == 0) {
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
n_embd_head_v_all = -1;
}
}
// [TAG_V_CACHE_VARIABLE]
@@ -336,8 +338,9 @@ llama_kv_cache::llama_kv_cache(
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
// always create Hadamard rotation tensors for DeepSeek lightning indexers
if ((model.arch == LLM_ARCH_DEEPSEEK32 || model.arch == LLM_ARCH_DEEPSEEK4) &&
hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
@@ -1220,6 +1223,23 @@ ggml_type llama_kv_cache::type_v() const {
return layers[0].v->type;
}
std::vector<uint32_t> llama_kv_cache::get_layer_ids() const {
std::vector<uint32_t> res;
res.reserve(layers.size());
for (const auto & layer : layers) {
res.push_back(layer.il);
}
return res;
}
ggml_tensor * llama_kv_cache::get_k_storage(int32_t il) const {
const int32_t ikv = map_layer_ids.at(il);
return layers[ikv].k;
}
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
+3
View File
@@ -161,6 +161,9 @@ public:
ggml_type type_k() const;
ggml_type type_v() const;
std::vector<uint32_t> get_layer_ids() const;
ggml_tensor * get_k_storage(int32_t il) const;
//
// graph_build API
//
+3
View File
@@ -294,6 +294,8 @@ namespace GGUFMeta {
}
template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required);
template std::enable_if<std::is_integral<uint32_t>::value, bool>::type
llama_model_loader::get_arr_n<uint32_t>(const std::string & key, uint32_t & result, bool required);
template<typename T>
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
@@ -395,6 +397,7 @@ namespace GGUFMeta {
template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required);
template bool llama_model_loader::get_arr<std::vector<int32_t>>(enum llm_kv kid, std::vector<int32_t> & result, bool required);
template bool llama_model_loader::get_arr<std::array<uint32_t, LLAMA_MAX_LAYERS>>(enum llm_kv kid, std::array<uint32_t, LLAMA_MAX_LAYERS> & result, bool required);
template<typename T>
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
+28 -1
View File
@@ -11,6 +11,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-kv-cache-dsa.h"
#include "llama-kv-cache-dsv4.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
@@ -181,6 +182,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_deepseek2ocr(params);
case LLM_ARCH_DEEPSEEK32:
return new llama_model_deepseek32(params);
case LLM_ARCH_DEEPSEEK4:
return new llama_model_deepseek4(params);
case LLM_ARCH_GLM_DSA:
return new llama_model_glm_dsa(params);
case LLM_ARCH_MISTRAL4:
@@ -817,6 +820,7 @@ static const char * llama_expert_gating_func_name(llama_expert_gating_func_type
switch (type) {
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax";
case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid";
case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS: return "sqrtsoftplus";
default: return "unknown";
}
}
@@ -2156,7 +2160,24 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
}
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
if (arch == LLM_ARCH_DEEPSEEK4) {
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE);
res = new llama_kv_cache_dsv4(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
filter,
reuse);
} else if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
@@ -2328,6 +2349,11 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
}
int32_t llama_model_n_swa(const llama_model * model) {
// dsv4 kv-cache has SWA but it cannot be used as a rollback because of
// other compression ratios, so we return 0 here
if (model->arch == LLM_ARCH_DEEPSEEK4) {
return 0;
}
return model->hparams.n_swa;
}
@@ -2409,6 +2435,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
case LLM_ARCH_DEEPSEEK32:
case LLM_ARCH_DEEPSEEK4:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
+25
View File
@@ -255,9 +255,11 @@ struct llama_layer {
struct ggml_tensor * wq_b = nullptr;
struct ggml_tensor * wkv_a_mqa = nullptr;
struct ggml_tensor * wkv_b = nullptr;
struct ggml_tensor * wkv = nullptr;
struct ggml_tensor * wk_b = nullptr;
struct ggml_tensor * wv_b = nullptr;
struct ggml_tensor * wqkv_b = nullptr;
struct ggml_tensor * wo_a = nullptr;
struct ggml_tensor * wo_b = nullptr;
struct ggml_tensor * wq_cross = nullptr;
struct ggml_tensor * wk_cross = nullptr;
@@ -333,6 +335,7 @@ struct llama_layer {
struct ggml_tensor * ffn_up_b = nullptr; // b3
struct ggml_tensor * ffn_act = nullptr;
struct ggml_tensor * ffn_exp_probs_b = nullptr;
struct ggml_tensor * ffn_gate_tid2eid = nullptr;
// mamba proj
struct ggml_tensor * ssm_in = nullptr;
@@ -463,6 +466,23 @@ struct llama_layer {
// openai-moe
struct ggml_tensor * attn_sinks = nullptr;
// DeepSeek-V4
struct ggml_tensor * attn_kv_norm = nullptr;
struct ggml_tensor * hc_attn_fn = nullptr;
struct ggml_tensor * hc_attn_base = nullptr;
struct ggml_tensor * hc_attn_scale = nullptr;
struct ggml_tensor * hc_ffn_fn = nullptr;
struct ggml_tensor * hc_ffn_base = nullptr;
struct ggml_tensor * hc_ffn_scale = nullptr;
struct ggml_tensor * attn_comp_wkv = nullptr;
struct ggml_tensor * attn_comp_wgate = nullptr;
struct ggml_tensor * attn_comp_ape = nullptr;
struct ggml_tensor * attn_comp_norm = nullptr;
struct ggml_tensor * indexer_comp_wkv = nullptr;
struct ggml_tensor * indexer_comp_wgate = nullptr;
struct ggml_tensor * indexer_comp_ape = nullptr;
struct ggml_tensor * indexer_comp_norm = nullptr;
// cogvlm
struct ggml_tensor * visexp_attn_wqkv = nullptr;
struct ggml_tensor * visexp_attn_wo = nullptr;
@@ -553,6 +573,11 @@ struct llama_model {
struct ggml_tensor * nextn_proj_pre = nullptr;
struct ggml_tensor * nextn_proj_post = nullptr;
// DeepSeek-V4
struct ggml_tensor * hc_head_fn = nullptr;
struct ggml_tensor * hc_head_base = nullptr;
struct ggml_tensor * hc_head_scale = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;
File diff suppressed because it is too large Load Diff
+115
View File
@@ -1085,6 +1085,121 @@ struct llama_model_deepseek32 : public llama_model_base {
};
struct llama_model_deepseek4 : public llama_model_base {
llama_model_deepseek4(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
void load_arch_tensors(llama_model_loader & ml) override;
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
ggml_tensor * build_hc_pre(
ggml_tensor * x,
ggml_tensor * hc_fn,
ggml_tensor * hc_scale,
ggml_tensor * hc_base,
ggml_tensor ** post,
ggml_tensor ** comb,
int il) const;
ggml_tensor * build_hc_post(
ggml_tensor * x,
ggml_tensor * residual,
ggml_tensor * post,
ggml_tensor * comb,
int il) const;
ggml_tensor * build_hc_head(
ggml_tensor * x,
ggml_tensor * hc_fn,
ggml_tensor * hc_scale,
ggml_tensor * hc_base) const;
ggml_tensor * build_attention(
const llama_model & model,
llm_graph_input_dsv4 * inp_dsv4,
ggml_tensor * cur,
ggml_tensor * inp_pos,
int il) const;
ggml_tensor * build_hca_compressed_kv_from_state(
ggml_tensor * kv_state,
ggml_tensor * score_state,
ggml_tensor * state_read_idxs,
ggml_tensor * comp_pos,
ggml_tensor * norm,
int64_t n_embd_head,
const char * name,
int il) const;
ggml_tensor * build_overlap_compressed_kv_from_state(
ggml_tensor * kv_state,
ggml_tensor * score_state,
ggml_tensor * state_read_idxs,
ggml_tensor * comp_pos,
ggml_tensor * norm,
int64_t ratio,
int64_t n_embd_head,
const char * name,
int il) const;
ggml_tensor * build_lid_top_k(
const llama_model & model,
llm_graph_input_dsv4 * inp_dsv4,
ggml_tensor * qr,
ggml_tensor * cur,
ggml_tensor * inp_pos,
int il) const;
ggml_tensor * build_top_k_mask(
ggml_tensor * kq_mask,
ggml_tensor * top_k,
const char * name,
int il) const;
ggml_tensor * build_csa_lid_attention(
const llama_model & model,
llm_graph_input_dsv4 * inp_dsv4,
llm_graph_input_dsv4_raw * inp_attn,
ggml_tensor * q,
ggml_tensor * kv,
ggml_tensor * qr,
ggml_tensor * cur,
ggml_tensor * inp_pos,
ggml_tensor * sinks,
float kq_scale,
int il) const;
ggml_tensor * build_hca_attention(
llm_graph_input_dsv4 * inp_dsv4,
llm_graph_input_dsv4_raw * inp_attn,
ggml_tensor * q,
ggml_tensor * kv,
ggml_tensor * sinks,
float kq_scale,
int il) const;
ggml_tensor * build_raw_attention(
llm_graph_input_dsv4_raw * inp_attn,
ggml_tensor * q,
ggml_tensor * kv,
ggml_tensor * sinks,
float kq_scale,
int il) const;
ggml_tensor * build_hc_weighted_sum(
ggml_tensor * x,
ggml_tensor * weights) const;
ggml_tensor * build_hc_sinkhorn(
ggml_tensor * comb,
int il) const;
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_deepseek2ocr : public llama_model_base {
llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
+2
View File
@@ -121,6 +121,8 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
res->t_layer_inp[il] = inpL;
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
-1
View File
@@ -211,7 +211,6 @@ llama_build_and_test(
peg-parser/test-unicode.cpp
peg-parser/tests.h
)
llama_build_and_test(test-regex-partial.cpp)
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf")
+1
View File
@@ -7759,6 +7759,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 70000, 4, 1, false)); // row count > CUDA grid-y limit (65535)
for (ggml_type type : all_types) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
+3
View File
@@ -412,6 +412,9 @@ static bool arch_supported(const llm_arch arch) {
if (arch == LLM_ARCH_DEEPSEEK2OCR) {
return false;
}
if (arch == LLM_ARCH_DEEPSEEK4) {
return false;
}
// FIXME some models are segfaulting with WebGPU:
#ifdef GGML_USE_WEBGPU
-288
View File
@@ -1,288 +0,0 @@
// Tests common_regex (esp. its partial final matches support).
#include "common.h"
#include "regex-partial.h"
#include <sstream>
#include <iostream>
#include <optional>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << " Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
struct test_case {
std::string pattern;
struct input_output {
std::string input;
common_regex_match output;
};
std::vector<input_output> inputs_outputs;
};
static std::string common_regex_match_type_name(common_regex_match_type type) {
switch (type) {
case COMMON_REGEX_MATCH_TYPE_NONE:
return "COMMON_REGEX_MATCH_TYPE_NONE";
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
case COMMON_REGEX_MATCH_TYPE_FULL:
return "COMMON_REGEX_MATCH_TYPE_FULL";
}
return "?";
}
static void test_regex() {
printf("[%s]\n", __func__);
auto test = [](const test_case & test_case) {
common_regex cr(test_case.pattern);
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
for (const auto & input_output : test_case.inputs_outputs) {
std::cout << " Input: " << input_output.input << '\n';
auto m = cr.search(input_output.input, 0);
if (m != input_output.output) {
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
std::ostringstream ss;
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
ss << "<no match>";
} else {
GGML_ASSERT(!input_output.output.groups.empty());
std::vector<std::string> parts;
for (const auto & g : m->groups) {
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
}
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
}
return ss.str();
};
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
std::cout << " Got: " << match_to_str(m) << '\n';
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
throw std::runtime_error("Test failed");
}
}
};
test({
"a",
{
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
}
});
test({
"abcd",
{
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"d", {}},
{"bcd", {}},
{"cde", {}},
{"cd", {}},
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
{"abbie", {}},
{"", {}},
}
});
test({
".*?ab",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
}
});
test({
"a.*?b",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"d", {}},
{"b", {}},
}
});
test({
"ab(?:cd){2,4}ef",
{
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abcde", {}},
{"abcdef", {}},
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
{"abcdcdcdcdcdef", {}},
{"abcde", {}},
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
}
});
test({
"a(?:rte| pure )fact",
{
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"fact", {}},
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
{"" , {}},
{"pure", {}},
{"pure fact", {}},
}
});
test({
"abc",
{
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"b", {}},
{"c", {}},
{"", {}},
}
});
test({
"(?:abc)?\\s*def",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
}
});
test({
"a+b",
{
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
}
});
test({
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
{
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
}
});
}
static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__);
assert_equals<std::string>(
"^((?:(?:c)?b)?a)",
regex_to_reversed_partial_regex("abc"));
assert_equals<std::string>(
"^(a+)",
regex_to_reversed_partial_regex("a+"));
assert_equals<std::string>(
"^(a*)",
regex_to_reversed_partial_regex("a*"));
assert_equals<std::string>(
"^(a?)",
regex_to_reversed_partial_regex("a?"));
assert_equals<std::string>(
"^([a-z])",
regex_to_reversed_partial_regex("[a-z]"));
assert_equals<std::string>(
"^((?:\\w+)?[a-z])",
regex_to_reversed_partial_regex("[a-z]\\w+"));
assert_equals<std::string>(
"^((?:a|b))",
regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals<std::string>(
"^((?:(?:(?:d)?c)?b)?a)",
regex_to_reversed_partial_regex("abcd"));
assert_equals<std::string>(
"^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b"));
assert_equals<std::string>(
"^((?:(?:b)?a)?.*)",
regex_to_reversed_partial_regex(".*?ab"));
assert_equals<std::string>(
"^((?:(?:b)?.*)?a)",
regex_to_reversed_partial_regex("a.*?b"));
assert_equals<std::string>(
"^((?:(?:d)?(?:(?:c)?b))?a)",
regex_to_reversed_partial_regex("a(bc)d"));
assert_equals<std::string>(
"^((?:(?:(?:c)?b|(?:e)?d))?a)",
regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals<std::string>(
"^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
regex_to_reversed_partial_regex("ab{2,4}c"));
}
int main() {
test_regex_to_reversed_partial_regex();
test_regex();
std::cout << "All tests passed.\n";
}
+13
View File
@@ -1538,6 +1538,19 @@ private:
/* media_path */ params_base.media_path,
/* force_pure_content */ params_base.force_pure_content_parser
};
{
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
auto it = params_base.default_template_kwargs.find("preserve_reasoning");
bool supported = caps.at("supports_preserve_reasoning");
bool enabled = it != params_base.default_template_kwargs.end();
if (supported && !enabled) {
SRV_INF("%s", "chat template supports preserving reasoning, consider enabling it via --reasoning-preserve\n");
}
if (!supported && enabled) {
SRV_WRN("%s", "chat template does NOT support preserving reasoning, --reasoning-preserve has no effect\n");
}
}
}
return true;
+1 -1
View File
@@ -39,7 +39,7 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme);
}
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), common_http_format_host(parsed_url.host).c_str(), parsed_url.port, parsed_url.path.c_str());
std::map<std::string, std::string> headers;
const std::string proxy_header_prefix = "x-llama-server-proxy-header-";
+2 -1
View File
@@ -1,4 +1,5 @@
#include "common.h"
#include "http.h"
#include "server-http.h"
#include "server-stream.h"
#include "server-common.h"
@@ -441,7 +442,7 @@ bool server_http_context::start() {
srv->wait_until_ready();
listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
: string_format("%s://%s:%d", is_ssl ? "https" : "http", hostname.c_str(), port);
: string_format("%s://%s:%d", is_ssl ? "https" : "http", common_http_format_host(hostname).c_str(), port);
return true;
}
+3 -1
View File
@@ -1,4 +1,5 @@
#include "server-common.h"
#include "http.h"
#include "server-models.h"
#include "server-context.h"
#include "server-stream.h"
@@ -2263,7 +2264,8 @@ server_http_proxy::server_http_proxy(
}
if (lowered == "host") {
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
const std::string url_host = common_http_format_host(host);
req.set_header(key, is_default_port ? url_host : url_host + ":" + std::to_string(port));
} else {
req.set_header(key, value);
}
+1
View File
@@ -8,6 +8,7 @@ set(UI_SOURCE_GLOBS
set(UI_SOURCE_FILES
package.json
package-lock.json
src/.gitignore
vite.config.ts
svelte.config.js
tsconfig.json

Some files were not shown because too many files have changed in this diff Show More