mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-29 00:57:39 +02:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e489a5ca0e | |||
| e21cdc11a0 | |||
| e974923698 | |||
| 1c0d9081fd | |||
| a8bad3842e | |||
| 75f3bc94e6 | |||
| aa00911d12 | |||
| ce8fd4b1a6 | |||
| 9f5e1edb10 | |||
| 920b3e78cb | |||
| 974c8c94cc | |||
| 227ed28e12 | |||
| bafae27654 | |||
| 873c825611 | |||
| 82764d8f40 | |||
| 21a4933042 | |||
| 1e9d771e2c | |||
| aa4695c5e5 | |||
| 547765a93e | |||
| 9e209c5aee | |||
| 6313acbef0 | |||
| ff5ef82786 | |||
| 073bb2c20b | |||
| af1127d3c4 | |||
| 865ff06b2f | |||
| 2b2cd57de6 | |||
| 660386f6f8 | |||
| a29e4c0b7b | |||
| b136b62cf9 | |||
| 81069a808a | |||
| 9aa2807769 | |||
| 3fc65063d9 | |||
| 05b3caaa48 | |||
| e62fa13c24 | |||
| bfd1f453cb | |||
| e4fed9d08d | |||
| 5dd102539b | |||
| fb38d6f278 |
@@ -17,7 +17,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap,security"
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
set( CMAKE_SYSTEM_NAME Linux )
|
||||
set( CMAKE_SYSTEM_PROCESSOR arm64 )
|
||||
|
||||
set( target aarch64-linux-gnu )
|
||||
|
||||
set( CMAKE_C_COMPILER clang )
|
||||
set( CMAKE_CXX_COMPILER clang++ )
|
||||
|
||||
set( CMAKE_C_COMPILER_TARGET ${target} )
|
||||
set( CMAKE_CXX_COMPILER_TARGET ${target} )
|
||||
|
||||
set( arch_c_flags "-march=armv8.7-a -fvectorize -ffp-model=fast -fno-finite-math-only" )
|
||||
set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function -Wno-gnu-zero-variadic-macro-arguments" )
|
||||
|
||||
set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
||||
set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
||||
|
||||
+10
-7
@@ -291,14 +291,16 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
|
||||
hf_tag = "default";
|
||||
}
|
||||
|
||||
const bool offline = params.offline;
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
|
||||
|
||||
// prepare local path for caching
|
||||
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
|
||||
auto preset_path = fs_get_cache_file(preset_fname);
|
||||
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.offline = params.offline;
|
||||
const int status = common_download_file_single(preset_url, preset_path, opts);
|
||||
const bool has_preset = status >= 200 && status < 400;
|
||||
|
||||
// remote preset is optional, so we don't error out if not found
|
||||
@@ -341,10 +343,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
model.hf_file = model.path;
|
||||
model.path = "";
|
||||
}
|
||||
common_download_model_opts opts;
|
||||
opts.download_mmproj = true;
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = bearer_token;
|
||||
opts.offline = offline;
|
||||
auto download_result = common_download_model(model, bearer_token, opts);
|
||||
auto download_result = common_download_model(model, opts, true);
|
||||
|
||||
if (download_result.model_path.empty()) {
|
||||
LOG_ERR("error: failed to download model from Hugging Face\n");
|
||||
@@ -365,9 +367,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
common_download_model_opts opts;
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = bearer_token;
|
||||
opts.offline = offline;
|
||||
auto download_result = common_download_model(model, bearer_token, opts);
|
||||
auto download_result = common_download_model(model, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
|
||||
exit(1);
|
||||
|
||||
@@ -69,6 +69,10 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
|
||||
+216
-9
@@ -865,9 +865,10 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
adjusted_messages.push_back(adjusted);
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
@@ -887,7 +888,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();
|
||||
|
||||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
if (has_response_format) {
|
||||
// Ministral wants to emit json surrounded by code fences
|
||||
return generation_prompt + (reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```");
|
||||
}
|
||||
@@ -928,6 +929,10 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
@@ -1063,6 +1068,10 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
@@ -1082,6 +1091,14 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
|
||||
if (inputs.add_generation_prompt && string_ends_with(data.prompt, "<turn|>\n")) {
|
||||
// This may happen if the model generates content + tool_call, the
|
||||
// template does not add the model's next turn and confuses the model
|
||||
// from emitting its proper reasoning token sequence.
|
||||
data.prompt += "<|turn>model\n";
|
||||
}
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<|channel>thought";
|
||||
@@ -1109,7 +1126,8 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
p.rule("thought", p.content(p.literal("<|channel>thought") + p.space() + p.until("<channel|>") + p.literal("<channel|>")));
|
||||
}
|
||||
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
auto consume_empty_channels = p.gbnf(p.zero_or_more(p.literal("<|channel>") + p.negate(p.literal("thought"))), "");
|
||||
auto thought = (p.peek(p.literal("<|channel>")) + consume_empty_channels + p.ref("thought")) | p.negate(p.literal("<|channel>"));
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.literal("```json") <<
|
||||
@@ -1173,12 +1191,16 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
/* max = */ inputs.parallel_tool_calls ? -1 : 1
|
||||
));
|
||||
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<|tool_call>"})));
|
||||
auto scan_to_toolcall = p.rule("scan-to-toolcall", p.until("<|tool_call>"));
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>", "<|tool_call>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.zero_or_more(message) + tool_call;
|
||||
return start + p.zero_or_more(message) + scan_to_toolcall + tool_call;
|
||||
}
|
||||
|
||||
auto content = p.rule("content", p.content(p.until("<|channel>")));
|
||||
// Gemma 4 may emit an extra <|channel>thought\n<channel|> at the end of the content. It may
|
||||
// also emit a single trailing <channel|> token. Consume all complete reasoning blocks and
|
||||
// then stop at the first unmatched <channel|> token.
|
||||
auto content = p.rule("content", p.content(p.until_one_of({"<|channel>", "<channel|>"})));
|
||||
auto message = p.rule("message", thought + content);
|
||||
return start + p.one_or_more(message);
|
||||
});
|
||||
@@ -1193,6 +1215,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
@@ -1643,6 +1669,173 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_deepseek_v3_2(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"|DSML|",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
const std::string DSML = "|DSML|";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
const std::string FC_START = "<" + DSML + "function_calls>";
|
||||
const std::string FC_END = "</" + DSML + "function_calls>";
|
||||
const std::string INVOKE_START = "<" + DSML + "invoke";
|
||||
const std::string INVOKE_END = "</" + DSML + "invoke>";
|
||||
const std::string PARAM_START = "<" + DSML + "parameter";
|
||||
const std::string PARAM_END = "</" + DSML + "parameter>";
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning && inputs.enable_thinking) {
|
||||
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
|
||||
} else if (extract_reasoning) {
|
||||
// Thinking disabled but reasoning extraction requested: the generation prompt
|
||||
// contains an empty <think></think> pair that must still be consumed.
|
||||
reasoning = p.optional(p.literal(THINK_START) + p.until(THINK_END) + p.literal(THINK_END));
|
||||
}
|
||||
|
||||
if (has_response_format) {
|
||||
auto response_format = p.rule("response-format",
|
||||
p.literal("```json") + p.space() +
|
||||
p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)) +
|
||||
p.space() + p.literal("```"));
|
||||
return generation_prompt + reasoning + response_format + end;
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_choice = p.choice();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
const auto & props = params.contains("properties") ? params.at("properties") : json::object();
|
||||
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required")) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
std::vector<common_peg_parser> required_parsers;
|
||||
std::vector<common_peg_parser> optional_parsers;
|
||||
for (const auto & [param_name, param_schema] : props.items()) {
|
||||
bool is_required = required.find(param_name) != required.end();
|
||||
bool is_string = schema_info.resolves_to_string(param_schema);
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(
|
||||
p.literal(PARAM_START + " name=\"") +
|
||||
p.tool_arg_name(p.literal(param_name)) +
|
||||
p.literal("\" string=\"" + std::string(is_string ? "true" : "false") + "\">")) +
|
||||
(is_string
|
||||
? p.tool_arg_string_value(p.until(PARAM_END))
|
||||
: p.tool_arg_json_value(p.schema(p.json(),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(PARAM_END)));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
required_parsers.push_back(named_arg);
|
||||
} else {
|
||||
optional_parsers.push_back(named_arg);
|
||||
}
|
||||
}
|
||||
|
||||
common_peg_parser args_seq = p.eps();
|
||||
for (size_t i = 0; i < required_parsers.size(); i++) {
|
||||
if (i > 0) {
|
||||
args_seq = args_seq + p.space();
|
||||
}
|
||||
args_seq = args_seq + required_parsers[i];
|
||||
}
|
||||
|
||||
if (!optional_parsers.empty()) {
|
||||
common_peg_parser any_opt = p.choice();
|
||||
for (const auto & opt : optional_parsers) {
|
||||
any_opt |= opt;
|
||||
}
|
||||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, -1);
|
||||
}
|
||||
|
||||
common_peg_parser invoke_body = args_seq;
|
||||
auto func_parser = p.tool(
|
||||
p.tool_open(p.literal(INVOKE_START + " name=\"") +
|
||||
p.tool_name(p.literal(name)) + p.literal("\">\n")) +
|
||||
invoke_body + p.space() +
|
||||
p.tool_close(p.literal(INVOKE_END)));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_calls = p.eps();
|
||||
if (inputs.parallel_tool_calls) {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice +
|
||||
p.zero_or_more(p.space() + tool_choice) + p.space() + p.literal(FC_END));
|
||||
} else {
|
||||
tool_calls = p.trigger_rule("tool-call",
|
||||
p.literal(FC_START) + p.space() + tool_choice + p.space() + p.literal(FC_END));
|
||||
}
|
||||
|
||||
if (!require_tools) {
|
||||
tool_calls = p.optional(tool_calls);
|
||||
}
|
||||
|
||||
auto content_before_tools = p.content(p.until(FC_START));
|
||||
return generation_prompt + reasoning + content_before_tools + tool_calls + end;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, FC_START },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
namespace workaround {
|
||||
|
||||
static void map_developer_role_to_system(json & messages) {
|
||||
@@ -1914,9 +2107,23 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return common_chat_params_init_gigachat_v3(tmpl, params);
|
||||
}
|
||||
|
||||
// DeepSeek V3.2 format detection: template defines dsml_token and uses it for tool calls.
|
||||
// The template source contains the token as a variable assignment, not as a literal in markup.
|
||||
if (src.find("dsml_token") != std::string::npos &&
|
||||
src.find("function_calls") != std::string::npos &&
|
||||
src.find("DSML") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: DeepSeek V3.2\n");
|
||||
return common_chat_params_init_deepseek_v3_2(tmpl, params);
|
||||
}
|
||||
|
||||
// Gemma4 format detection
|
||||
if (src.find("'<|tool_call>call:'") != std::string::npos) {
|
||||
workaround::convert_tool_responses_gemma4(params.messages);
|
||||
if (src.find("{#- OpenAI Chat Completions:") == std::string::npos) {
|
||||
// apply workarounds if using the older gemma4 templates
|
||||
LOG_WRN("%s: detected an outdated gemma4 chat template, applying compatibility workarounds. "
|
||||
"Consider updating to the official template.\n", __func__);
|
||||
workaround::convert_tool_responses_gemma4(params.messages);
|
||||
}
|
||||
return common_chat_params_init_gemma4(tmpl, params);
|
||||
}
|
||||
|
||||
|
||||
+112
-68
@@ -114,7 +114,7 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
|
||||
return {hf_repo, tag};
|
||||
}
|
||||
|
||||
class ProgressBar {
|
||||
class ProgressBar : public common_download_callback {
|
||||
static inline std::mutex mutex;
|
||||
static inline std::map<const ProgressBar *, int> lines;
|
||||
static inline int max_line = 0;
|
||||
@@ -138,7 +138,11 @@ class ProgressBar {
|
||||
}
|
||||
|
||||
public:
|
||||
ProgressBar(const std::string & url = "") : filename(url) {
|
||||
ProgressBar() = default;
|
||||
|
||||
void on_start(const common_download_progress & p) override {
|
||||
filename = p.url;
|
||||
|
||||
if (auto pos = filename.rfind('/'); pos != std::string::npos) {
|
||||
filename = filename.substr(pos + 1);
|
||||
}
|
||||
@@ -156,13 +160,13 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
~ProgressBar() {
|
||||
void on_done(const common_download_progress &, bool) override {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
cleanup(this);
|
||||
}
|
||||
|
||||
void update(size_t current, size_t total) {
|
||||
if (!total || !is_output_a_tty()) {
|
||||
void on_update(const common_download_progress & p) override {
|
||||
if (!p.total || !is_output_a_tty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -175,8 +179,8 @@ public:
|
||||
int lines_up = max_line - lines[this];
|
||||
|
||||
size_t bar = (55 - len) * 2;
|
||||
size_t pct = (100 * current) / total;
|
||||
size_t pos = (bar * current) / total;
|
||||
size_t pct = (100 * p.downloaded) / p.total;
|
||||
size_t pos = (bar * p.downloaded) / p.total;
|
||||
|
||||
if (lines_up > 0) {
|
||||
std::cout << "\033[" << lines_up << "A";
|
||||
@@ -193,7 +197,7 @@ public:
|
||||
}
|
||||
std::cout << '\r' << std::flush;
|
||||
|
||||
if (current == total) {
|
||||
if (p.downloaded == p.total) {
|
||||
cleanup(this);
|
||||
}
|
||||
}
|
||||
@@ -206,8 +210,8 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
const std::string & resolve_path,
|
||||
const std::string & path_tmp,
|
||||
bool supports_ranges,
|
||||
size_t existing_size,
|
||||
size_t & total_size) {
|
||||
common_download_progress & p,
|
||||
common_download_callback * callback) {
|
||||
std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
|
||||
if (!ofs.is_open()) {
|
||||
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
|
||||
@@ -215,29 +219,27 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
}
|
||||
|
||||
httplib::Headers headers;
|
||||
if (supports_ranges && existing_size > 0) {
|
||||
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
|
||||
if (supports_ranges && p.downloaded > 0) {
|
||||
headers.emplace("Range", "bytes=" + std::to_string(p.downloaded) + "-");
|
||||
}
|
||||
|
||||
const char * func = __func__; // avoid __func__ inside a lambda
|
||||
size_t downloaded = existing_size;
|
||||
size_t progress_step = 0;
|
||||
ProgressBar bar(resolve_path);
|
||||
|
||||
auto res = cli.Get(resolve_path, headers,
|
||||
[&](const httplib::Response &response) {
|
||||
if (existing_size > 0 && response.status != 206) {
|
||||
if (p.downloaded > 0 && response.status != 206) {
|
||||
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (existing_size == 0 && response.status != 200) {
|
||||
if (p.downloaded == 0 && response.status != 200) {
|
||||
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
|
||||
return false;
|
||||
}
|
||||
if (total_size == 0 && response.has_header("Content-Length")) {
|
||||
if (p.total == 0 && response.has_header("Content-Length")) {
|
||||
try {
|
||||
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
|
||||
total_size = existing_size + content_length;
|
||||
p.total = p.downloaded + content_length;
|
||||
} catch (const std::exception &e) {
|
||||
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
|
||||
}
|
||||
@@ -250,11 +252,16 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
|
||||
return false;
|
||||
}
|
||||
downloaded += len;
|
||||
p.downloaded += len;
|
||||
progress_step += len;
|
||||
|
||||
if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
||||
bar.update(downloaded, total_size);
|
||||
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
|
||||
if (callback) {
|
||||
callback->on_update(p);
|
||||
if (callback->is_cancelled()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
progress_step = 0;
|
||||
}
|
||||
return true;
|
||||
@@ -275,28 +282,13 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers,
|
||||
bool skip_etag = false) {
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const common_download_opts & opts,
|
||||
bool skip_etag) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : custom_headers) {
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
if (!bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + bearer_token);
|
||||
}
|
||||
cli.set_default_headers(headers);
|
||||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
if (file_exists && skip_etag) {
|
||||
@@ -304,6 +296,20 @@ static int common_download_file_single_online(const std::string & url,
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : opts.headers) {
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
if (!opts.bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + opts.bearer_token);
|
||||
}
|
||||
cli.set_default_headers(headers);
|
||||
|
||||
std::string last_etag;
|
||||
if (file_exists) {
|
||||
last_etag = read_etag(path);
|
||||
@@ -326,10 +332,11 @@ static int common_download_file_single_online(const std::string & url,
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
if (head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
p.total = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
@@ -357,14 +364,21 @@ static int common_download_file_single_online(const std::string & url,
|
||||
|
||||
{ // silent
|
||||
std::error_code ec;
|
||||
std::filesystem::path p(path);
|
||||
std::filesystem::create_directories(p.parent_path(), ec);
|
||||
std::filesystem::create_directories(std::filesystem::path(path).parent_path(), ec);
|
||||
}
|
||||
|
||||
bool success = false;
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
int delay = retry_delay_seconds;
|
||||
|
||||
if (opts.callback) {
|
||||
opts.callback->on_start(p);
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (opts.callback && opts.callback->is_cancelled()) {
|
||||
break;
|
||||
}
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
@@ -378,28 +392,44 @@ static int common_download_file_single_online(const std::string & url,
|
||||
existing_size = std::filesystem::file_size(path_temporary);
|
||||
} else if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
return -1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
p.downloaded = existing_size;
|
||||
|
||||
LOG_DBG("%s: downloading from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(),
|
||||
path_temporary.c_str(), etag.c_str());
|
||||
|
||||
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
|
||||
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, p, opts.callback)) {
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
break;
|
||||
}
|
||||
if (!etag.empty() && !skip_etag) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
return head->status;
|
||||
success = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
if (opts.callback) {
|
||||
opts.callback->on_done(p, success);
|
||||
}
|
||||
if (opts.callback && opts.callback->is_cancelled() &&
|
||||
std::filesystem::exists(path_temporary)) {
|
||||
if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
|
||||
}
|
||||
}
|
||||
if (!success) {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
return head->status;
|
||||
}
|
||||
|
||||
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
|
||||
@@ -438,12 +468,15 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers,
|
||||
const common_download_opts & opts,
|
||||
bool skip_etag) {
|
||||
if (!offline) {
|
||||
return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
|
||||
if (!opts.offline) {
|
||||
ProgressBar tty_cb;
|
||||
common_download_opts online_opts = opts;
|
||||
if (!online_opts.callback) {
|
||||
online_opts.callback = &tty_cb;
|
||||
}
|
||||
return common_download_file_single_online(url, path, online_opts, skip_etag);
|
||||
}
|
||||
|
||||
if (!std::filesystem::exists(path)) {
|
||||
@@ -452,6 +485,16 @@ int common_download_file_single(const std::string & url,
|
||||
}
|
||||
|
||||
LOG_DBG("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
|
||||
|
||||
// notify the callback that the file was cached
|
||||
if (opts.callback) {
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
p.cached = true;
|
||||
opts.callback->on_start(p);
|
||||
opts.callback->on_done(p, true);
|
||||
}
|
||||
|
||||
return 304; // Not Modified - fake cached response
|
||||
}
|
||||
|
||||
@@ -631,16 +674,16 @@ struct hf_plan {
|
||||
hf_cache::hf_file mmproj;
|
||||
};
|
||||
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const std::string & token,
|
||||
const common_download_model_opts & opts) {
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj) {
|
||||
hf_plan plan;
|
||||
hf_cache::hf_files all;
|
||||
|
||||
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
|
||||
|
||||
if (!opts.offline) {
|
||||
all = hf_cache::get_repo_files(repo, token);
|
||||
all = hf_cache::get_repo_files(repo, opts.bearer_token);
|
||||
}
|
||||
if (all.empty()) {
|
||||
all = hf_cache::get_cached_files(repo);
|
||||
@@ -675,7 +718,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
plan.primary = primary;
|
||||
plan.model_files = get_split_files(all, primary);
|
||||
|
||||
if (opts.download_mmproj) {
|
||||
if (download_mmproj) {
|
||||
plan.mmproj = find_best_mmproj(all, primary.path);
|
||||
}
|
||||
|
||||
@@ -710,10 +753,9 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
|
||||
return tasks;
|
||||
}
|
||||
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
const common_download_model_opts & opts,
|
||||
const common_header_list & headers) {
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj) {
|
||||
common_download_model_result result;
|
||||
std::vector<download_task> tasks;
|
||||
hf_plan hf;
|
||||
@@ -721,7 +763,7 @@ common_download_model_result common_download_model(const common_params_model
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
hf = get_hf_plan(model, bearer_token, opts);
|
||||
hf = get_hf_plan(model, opts, download_mmproj);
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
@@ -742,8 +784,8 @@ common_download_model_result common_download_model(const common_params_model
|
||||
std::vector<std::future<bool>> futures;
|
||||
for (const auto & task : tasks) {
|
||||
futures.push_back(std::async(std::launch::async,
|
||||
[&task, &bearer_token, offline = opts.offline, &headers, is_hf]() {
|
||||
int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf);
|
||||
[&task, &opts, is_hf]() {
|
||||
int status = common_download_file_single(task.url, task.path, opts, is_hf);
|
||||
return is_http_status_ok(status);
|
||||
}
|
||||
));
|
||||
@@ -879,7 +921,9 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
std::string local_path = fs_get_cache_file(model_filename);
|
||||
|
||||
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
|
||||
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = token;
|
||||
const int http_status = common_download_file_single(blob_url, local_path, opts);
|
||||
if (!is_http_status_ok(http_status)) {
|
||||
throw std::runtime_error("Failed to download Docker Model");
|
||||
}
|
||||
|
||||
+25
-10
@@ -8,6 +8,22 @@ struct common_params_model;
|
||||
using common_header = std::pair<std::string, std::string>;
|
||||
using common_header_list = std::vector<common_header>;
|
||||
|
||||
struct common_download_progress {
|
||||
std::string url;
|
||||
size_t downloaded = 0;
|
||||
size_t total = 0;
|
||||
bool cached = false;
|
||||
};
|
||||
|
||||
class common_download_callback {
|
||||
public:
|
||||
virtual ~common_download_callback() = default;
|
||||
virtual void on_start(const common_download_progress & p) = 0;
|
||||
virtual void on_update(const common_download_progress & p) = 0;
|
||||
virtual void on_done(const common_download_progress & p, bool ok) = 0;
|
||||
virtual bool is_cancelled() const { return false; }
|
||||
};
|
||||
|
||||
struct common_remote_params {
|
||||
common_header_list headers;
|
||||
long timeout = 0; // in seconds, 0 means no timeout
|
||||
@@ -31,10 +47,12 @@ struct common_cached_model_info {
|
||||
}
|
||||
};
|
||||
|
||||
// Options for common_download_model
|
||||
struct common_download_model_opts {
|
||||
bool download_mmproj = false;
|
||||
bool offline = false;
|
||||
// Options for common_download_model and common_download_file_single
|
||||
struct common_download_opts {
|
||||
std::string bearer_token;
|
||||
common_header_list headers;
|
||||
bool offline = false;
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
// Result of common_download_model
|
||||
@@ -69,9 +87,8 @@ struct common_download_model_result {
|
||||
// returns result with model_path and mmproj_path (empty on failure)
|
||||
common_download_model_result common_download_model(
|
||||
const common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
const common_download_model_opts & opts = {},
|
||||
const common_header_list & headers = {}
|
||||
const common_download_opts & opts = {},
|
||||
bool download_mmproj = false
|
||||
);
|
||||
|
||||
// returns list of cached models
|
||||
@@ -82,9 +99,7 @@ std::vector<common_cached_model_info> common_list_cached_models();
|
||||
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
bool offline,
|
||||
const common_header_list & headers = {},
|
||||
const common_download_opts & opts = {},
|
||||
bool skip_etag = false);
|
||||
|
||||
// resolve and download model from Docker registry
|
||||
|
||||
+27
-2
@@ -890,6 +890,10 @@ struct parser_executor {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -957,7 +961,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_and_parser> ||
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -1036,6 +1041,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Not(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1565,6 +1572,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1651,10 +1659,13 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
std::string s;
|
||||
for (const auto & child : p.children) {
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
if (child_gbnf.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (!s.empty()) {
|
||||
s += " ";
|
||||
}
|
||||
auto child_gbnf = to_gbnf(child);
|
||||
const auto & child_parser = effective_parser(child);
|
||||
if (std::holds_alternative<common_peg_choice_parser>(child_parser) ||
|
||||
std::holds_alternative<common_peg_sequence_parser>(child_parser)) {
|
||||
@@ -1754,6 +1765,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_atomic_parser>) {
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1888,6 +1901,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
{"child", p.child},
|
||||
{"tag", p.tag}
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2050,6 +2065,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "gbnf") {
|
||||
if (!j.contains("child") || !j.contains("grammar")) {
|
||||
throw std::runtime_error("gbnf parser missing required fields");
|
||||
}
|
||||
return common_peg_gbnf_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["grammar"].get<std::string>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
+11
-1
@@ -270,6 +270,11 @@ struct common_peg_tag_parser {
|
||||
std::string tag;
|
||||
};
|
||||
|
||||
struct common_peg_gbnf_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -290,7 +295,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_rule_parser,
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -504,6 +510,10 @@ class common_peg_parser_builder {
|
||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||
|
||||
// Wraps a child parser but emits a custom GBNF grammar string instead of
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
+203
-19
@@ -4258,9 +4258,7 @@ class Qwen2VLVisionModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2_5OmniModel")
|
||||
class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
has_vision_encoder = True
|
||||
class Qwen25AudioModel(MmprojModel):
|
||||
has_audio_encoder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -4276,12 +4274,6 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"])
|
||||
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5))
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# SinusoidsPositionEmbedding
|
||||
assert self.hparams_audio is not None
|
||||
@@ -4312,7 +4304,32 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
|
||||
# this tensor is left unused in transformers code
|
||||
# https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809
|
||||
return
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
return # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2_5OmniModel")
|
||||
class Qwen25OmniModel(Qwen2VLVisionModel, Qwen25AudioModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = True
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "visual." in name:
|
||||
yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
|
||||
elif "audio_tower." in name:
|
||||
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
|
||||
return # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("InternVisionModel")
|
||||
@@ -4816,7 +4833,10 @@ class RND1Model(Qwen2MoeModel):
|
||||
class Qwen3VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
if self.hparams_vision is None:
|
||||
logger.info("No vision config found, skipping vision tensor processing")
|
||||
return
|
||||
|
||||
# Compute image_size if not present
|
||||
if "image_size" not in self.hparams_vision:
|
||||
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
|
||||
@@ -4837,7 +4857,9 @@ class Qwen3VLVisionModel(MmprojModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
# in case mixed modalities, the arch will be handled by subclass
|
||||
if not self.has_audio_encoder:
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
|
||||
if self.hparams_vision is not None:
|
||||
@@ -4925,11 +4947,64 @@ class Qwen3VLVisionModel(MmprojModel):
|
||||
return
|
||||
|
||||
if name.startswith("visual."):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
return
|
||||
yield from MmprojModel.modify_tensors(self, data_torch, name, bid)
|
||||
return # skip other tensors
|
||||
|
||||
# Fall back to parent class for other tensors
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
|
||||
class Qwen3OmniMmprojModel(Qwen3VLVisionModel, Qwen25AudioModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = True
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
if self.has_vision_encoder:
|
||||
return self.global_config["thinker_config"].get("vision_config")
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
if self.has_audio_encoder:
|
||||
return self.global_config["thinker_config"].get("audio_config")
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if self.has_vision_encoder:
|
||||
Qwen3VLVisionModel.set_gguf_parameters(self)
|
||||
self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.QWEN3VL)
|
||||
if self.has_audio_encoder:
|
||||
Qwen25AudioModel.set_gguf_parameters(self)
|
||||
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.QWEN3A)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "visual." in name:
|
||||
if not self.has_vision_encoder:
|
||||
raise ValueError(f"Model does not have vision encoder, but found tensor {name}")
|
||||
# need to transform vision tensor naming, so that modify_tensors() logic can be used correctly
|
||||
name = name.replace("thinker.visual.", "model.visual.")
|
||||
if ".merger_list." in name:
|
||||
name = name.replace(".merger_list.", ".deepstack_merger_list.")
|
||||
name = name.replace(".ln_q", ".norm")
|
||||
name = name.replace(".mlp.0", ".linear_fc1")
|
||||
name = name.replace(".mlp.2", ".linear_fc2")
|
||||
elif ".merger." in name:
|
||||
name = name.replace(".ln_q", ".norm")
|
||||
name = name.replace(".mlp.0", ".linear_fc1")
|
||||
name = name.replace(".mlp.2", ".linear_fc2")
|
||||
yield from Qwen3VLVisionModel.modify_tensors(self, data_torch, name, bid)
|
||||
elif "audio_tower." in name:
|
||||
if not self.has_audio_encoder:
|
||||
raise ValueError(f"Model does not have audio encoder, but found tensor {name}")
|
||||
if "conv2d" in name and name.endswith(".bias"):
|
||||
# transform conv2d bias [n_embd] --> [1, 1, n_embd]
|
||||
data_torch = data_torch.unsqueeze(-1).unsqueeze(-1)
|
||||
yield from Qwen25AudioModel.modify_tensors(self, data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3ASRForConditionalGeneration")
|
||||
class Qwen3ASRMmprojModel(Qwen3OmniMmprojModel):
|
||||
has_audio_encoder = True
|
||||
has_vision_encoder = False
|
||||
|
||||
|
||||
@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration")
|
||||
@@ -4992,6 +5067,8 @@ class Step3VLVisionModel(MmprojModel):
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if ".position_embd." in new_name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if ("mm.0." in new_name or "mm.1." in new_name) and new_name.endswith(".weight"):
|
||||
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
@@ -5030,9 +5107,10 @@ class Qwen3VLTextModel(Qwen3Model):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
if "thinker_config" in self.hparams:
|
||||
vision_config = self.hparams["thinker_config"].get("vision_config", {})
|
||||
else:
|
||||
vision_config = self.hparams.get("vision_config", {})
|
||||
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
|
||||
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
|
||||
|
||||
@@ -5101,6 +5179,70 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3OmniMoeForConditionalGeneration")
|
||||
class Qwen3OmniMoeTextModel(Qwen3VLMoeTextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# correct BOS/EOS tokens
|
||||
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_config = json.load(f)
|
||||
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
|
||||
for token_id, data in added_tokens.items():
|
||||
if data.get("content") == "<|im_end|>":
|
||||
self.gguf_writer.add_bos_token_id(int(token_id))
|
||||
self.gguf_writer.add_eos_token_id(int(token_id))
|
||||
break
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_num_deepstack_layers(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision and audio tensors - they go in the mmproj file
|
||||
if "visual." in name or "audio_tower." in name \
|
||||
or "talker." in name or "code2wav." in name:
|
||||
return
|
||||
|
||||
name = name.replace("thinker.", "")
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3ASRForConditionalGeneration")
|
||||
class Qwen3ASRTextModel(Qwen3VLTextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3VL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_num_deepstack_layers(0)
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# fix chat template, use correct chatml format
|
||||
self.gguf_writer.add_chat_template("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}")
|
||||
# correct BOS/EOS tokens
|
||||
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||
tokenizer_config = json.load(f)
|
||||
added_tokens = tokenizer_config.get("added_tokens_decoder", {})
|
||||
for token_id, data in added_tokens.items():
|
||||
if data.get("content") == "<|im_end|>":
|
||||
self.gguf_writer.add_bos_token_id(int(token_id))
|
||||
self.gguf_writer.add_eos_token_id(int(token_id))
|
||||
break
|
||||
|
||||
def modify_tensors(self, data_torch, name, bid):
|
||||
# qwen3-omni
|
||||
name = name.replace("thinker.", "")
|
||||
|
||||
# Skip vision and audio tensors - they go in the mmproj file
|
||||
if "visual." in name or "audio_tower." in name \
|
||||
or "talker." in name or "code2wav." in name:
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
class _LinearAttentionVReorderBase(Qwen3NextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses
|
||||
"""reorders V heads from grouped to tiled order for ggml broadcast
|
||||
@@ -11279,6 +11421,48 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
|
||||
|
||||
|
||||
@ModelBase.register("MERaLiON2ForConditionalGeneration")
|
||||
class MERaLiONWhisperEncoderModel(WhisperEncoderModel):
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def get_audio_config(self) -> dict[str, Any] | None:
|
||||
return self.global_config.get("speech_config")
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MERALION)
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config.get("speech_mlp_scale_factor", 15))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.startswith("text_decoder."):
|
||||
return
|
||||
|
||||
if name.startswith("speech_encoder."):
|
||||
name = name.replace("speech_encoder.", "audio_tower.")
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
return
|
||||
|
||||
suffix = "." + name.rsplit(".", 1)[-1]
|
||||
|
||||
if name.startswith("ln_speech."):
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MM_NORM_PRE, suffix=suffix), data_torch)
|
||||
return
|
||||
|
||||
if name.startswith("speech_audio_adapter."):
|
||||
if ".mlp_adapter.0." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 0, suffix=suffix), data_torch)
|
||||
elif ".gate_proj." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 1, suffix=suffix), data_torch)
|
||||
elif ".pool_proj." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 2, suffix=suffix), data_torch)
|
||||
elif ".out_proj." in name:
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.A_MMPROJ, 3, suffix=suffix), data_torch)
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("VoxtralForConditionalGeneration")
|
||||
class VoxtralWhisperEncoderModel(WhisperEncoderModel):
|
||||
has_vision_encoder = False # no vision encoder
|
||||
|
||||
@@ -52,10 +52,39 @@
|
||||
}
|
||||
},
|
||||
|
||||
{
|
||||
"name": "arm64-linux-snapdragon",
|
||||
"hidden": true,
|
||||
"architecture": { "value": "arm64", "strategy": "external" },
|
||||
"toolset": { "value": "host=x86_64", "strategy": "external" },
|
||||
"cacheVariables": {
|
||||
"CMAKE_TOOLCHAIN_FILE": "cmake/arm64-linux-clang.cmake",
|
||||
"CMAKE_C_FLAGS": "-march=armv8 -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8 -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "linux_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "OFF",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
{ "name": "arm64-android-snapdragon-debug" , "inherits": [ "base", "arm64-android-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-android-snapdragon-release", "inherits": [ "base", "arm64-android-snapdragon", "release" ] },
|
||||
|
||||
{ "name": "arm64-windows-snapdragon-debug" , "inherits": [ "base", "arm64-windows-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] }
|
||||
{ "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] },
|
||||
|
||||
{ "name": "arm64-linux-snapdragon-debug" , "inherits": [ "base", "arm64-linux-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-linux-snapdragon-release", "inherits": [ "base", "arm64-linux-snapdragon", "release" ] }
|
||||
]
|
||||
}
|
||||
|
||||
@@ -236,10 +236,6 @@ build: 6a8cf8914 (6733)
|
||||
Controls whether the Hexagon backend allocates host buffers. By default, all buffers except for REPACK are host buffers.
|
||||
This option is required for testing Ops that require REPACK buffers (MUL_MAT and MUL_MAT_ID).
|
||||
|
||||
- `GGML_HEXAGON_EXPERIMENTAL=1`
|
||||
Controls whether the Hexagon backend enables experimental features.
|
||||
This option is required for enabling/testing experimental Ops (FLASH_ATTN_EXT).
|
||||
|
||||
- `GGML_HEXAGON_VERBOSE=1`
|
||||
Enables verbose logging of Ops from the backend. Example output:
|
||||
|
||||
@@ -259,11 +255,17 @@ build: 6a8cf8914 (6733)
|
||||
Allows enabling specific stages of the processing pipeline:
|
||||
|
||||
- `0x1` Enable Op Queue (i.e., queuing Ops into NPU)
|
||||
- `0x2` Enable Dynamic Quantizer (if needed for the Op)
|
||||
- `0x4` Enable Op Compute (MUL_MAT, etc.)
|
||||
- `0x2` Enable Op Compute (MUL_MAT, etc.)
|
||||
|
||||
Examples:
|
||||
|
||||
`GGML_HEXAGON_OPMASK=0x1 llama-completion ...` - Ops are enqueued but NPU-side processing is stubbed out
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-completion ...` - NPU performs dynamic quantization and skips the rest
|
||||
`GGML_HEXAGON_OPMASK=0x7 llama-completion ...` - Full queuing and processing of Ops (default)
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-completion ...` - Full queuing and processing of Ops (default)
|
||||
|
||||
- `GGML_HEXAGON_OPFILTER=regex`
|
||||
Allows filtering (disabling) Ops that match the regex pattern:
|
||||
|
||||
Examples:
|
||||
|
||||
`GGML_HEXAGON_OPFILTER="FLASH_ATTN_EXT" llama-completion ...` - Disable Flash Attention on Hexagon (falls back to CPU or GPU)
|
||||
`GGML_HEXAGON_OPFILTER="ADD\|SUB" llama-completion ...` - Disable ADD and SUB on Hexagon (fall back to CPU or GPU)
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
# Snapdragon-based Linux devices
|
||||
|
||||
## Docker Setup
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Linux device is using the toolchain Docker image (see [github.com/snapdragon-toolchain](https://github.com/snapdragon-toolchain)).
|
||||
This image includes OpenCL SDK, Hexagon SDK, CMake, and the ARM64 Linux cross-compilation toolchain.
|
||||
|
||||
Cross-compilation is supported on **Linux X86** hosts. The resulting binaries are deployed to and run on the target **Qualcomm Snapdragon ARM64 Linux** device.
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-linux:v0.1
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
Note: The rest of the **Linux** build process assumes that you're running inside the toolchain container.
|
||||
|
||||
|
||||
## How to Build
|
||||
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-linux-snapdragon-release -B build-snapdragon
|
||||
|
||||
[d]/workspace> cmake --build build-snapdragon -j $(nproc)
|
||||
```
|
||||
|
||||
To generate an installable "package" simply use cmake --install, then zip it:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon
|
||||
[d]/workspace> zip -r pkg-snapdragon.zip pkg-snapdragon
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
For this step, you will deploy the built binaries and libraries to the target Linux device. Transfer `pkg-snapdragon.zip` to the target device, then unzip it and set up the environment variables:
|
||||
|
||||
```
|
||||
$ unzip pkg-snapdragon.zip
|
||||
$ cd pkg-snapdragon
|
||||
$ export LD_LIBRARY_PATH=./lib
|
||||
$ export ADSP_LIBRARY_PATH=./lib
|
||||
```
|
||||
|
||||
At this point, you should also download some models onto the device:
|
||||
|
||||
```
|
||||
$ wget https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf
|
||||
```
|
||||
|
||||
## How to Run
|
||||
Next, since we have setup the environment variables, we can run the llama-cli with the Hexagon backends:
|
||||
```
|
||||
$ ./bin/llama-cli -m Llama-3.2-3B-Instruct-Q4_0.gguf --device HTP0 -ngl 99 -p "what is the most popular cookie in the world?"
|
||||
```
|
||||
@@ -5,6 +5,7 @@ Adding a model requires few steps:
|
||||
1. Convert the model to GGUF
|
||||
2. Define the model architecture in `llama.cpp`
|
||||
3. Build the GGML graph implementation
|
||||
4. Optional: Add multimodal encoder implementation
|
||||
|
||||
After following these steps, you can open PR.
|
||||
|
||||
@@ -114,6 +115,21 @@ Some `ggml` backends do not support all operations. Backend implementations can
|
||||
|
||||
Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
|
||||
|
||||
### 4. Optional: Add multimodal encoder implementation
|
||||
|
||||
If the new model supports multimodal inputs, you will need to add a new encoder definition in `libmtmd`. You can find more information about llama.cpp's multimodal support in [the docs](../multimodal.md) and in the `tools/mtmd` source directory.
|
||||
|
||||
1. In the conversion script, make sure you add a subclass that extends `MmprojModel` or another class that inherits from the same base class.
|
||||
2. Add the encoder definition in `clip.cpp`.
|
||||
3. Implement the preprocessor in `mtmd.cpp`. In most cases, you can reuse an existing preprocessor.
|
||||
4. Implement the encoder GGML graph, either in a dedicated file if the model is truly different from existing ones, or by reusing an existing implementation (for example: siglip, pixtral, or qwen) and adding a model-specific projector.
|
||||
|
||||
Note:
|
||||
- Many multimodal encoders are based on models that are already supported. Make sure to read the existing encoder definitions in `tools/mtmd/models` before adding a new one. In `libmtmd`, it is generally better to extend an existing model than to duplicate code.
|
||||
- To debug the multimodal preprocessor and encoder, you can use [llama-mtmd-debug](tools/mtmd/debug/mtmd-debug.cpp).
|
||||
- Adding a model-specific API or CLI is an anti-pattern in `libmtmd`. The goal of `libmtmd` is to provide an easy-to-use, model-agnostic library for multimodal pipeline.
|
||||
- In most cases, `llama-mtmd-cli` should not be modified. If a model requires a specific prompt, either let the user provide it or bake it into the Jinja chat template.
|
||||
|
||||
## GGUF specification
|
||||
|
||||
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md
|
||||
|
||||
@@ -94,6 +94,11 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
# Moondream2 20250414 version
|
||||
(tool_name) -hf ggml-org/moondream2-20250414-GGUF
|
||||
|
||||
# Gemma 4
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-E4B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-26B-A4B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-31B-it-GGUF
|
||||
```
|
||||
|
||||
**Audio models**:
|
||||
@@ -109,6 +114,10 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
|
||||
# Mistral's Voxtral
|
||||
(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
|
||||
|
||||
# Qwen3-ASR
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-0.6B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-ASR-1.7B-GGUF
|
||||
```
|
||||
|
||||
**Mixed modalities**:
|
||||
@@ -118,6 +127,16 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF
|
||||
|
||||
# Qwen3 Omni
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen3-Omni-30B-A3B-Thinking-GGUF
|
||||
|
||||
# Gemma 4
|
||||
# Capabilities: audio input, vision input
|
||||
(tool_name) -hf ggml-org/gemma-4-E2B-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-4-E4B-it-GGUF
|
||||
```
|
||||
|
||||
## Finding more models:
|
||||
|
||||
@@ -664,6 +664,7 @@ void ggml_compute_forward_add(
|
||||
{
|
||||
ggml_compute_forward_add_non_quantized(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -1113,6 +1114,7 @@ void ggml_compute_forward_add1(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -1242,6 +1244,7 @@ void ggml_compute_forward_acc(
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -4331,6 +4334,7 @@ void ggml_compute_forward_out_prod(
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -4606,6 +4610,7 @@ void ggml_compute_forward_set(
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
bool is_capturing = false;
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does.
|
||||
// See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149
|
||||
// TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release.
|
||||
cudaStreamCaptureStatus capture_status;
|
||||
CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status));
|
||||
is_capturing = (capture_status != cudaStreamCaptureStatusNone);
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols, 0, sizeof(float) * 8, stream));
|
||||
} else if (is_capturing) {
|
||||
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
|
||||
} else {
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream));
|
||||
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
|
||||
temp_keys, temp_indices, dst, ncols * nrows, nrows,
|
||||
offset_iterator, offset_iterator + 1, stream));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1185,7 +1185,9 @@ struct ggml_cuda_graph {
|
||||
bool warmup_complete = false;
|
||||
struct node_properties {
|
||||
ggml_tensor node;
|
||||
void * node_src_data_ptrs[GGML_MAX_SRC];
|
||||
void * node_src_data_ptrs[GGML_MAX_SRC];
|
||||
int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
|
||||
size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
|
||||
};
|
||||
std::vector<node_properties> node_props;
|
||||
|
||||
|
||||
+19
-11
@@ -75,13 +75,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if constexpr (DKQ <= 256) {
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
return;
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
return;
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio > 4) {
|
||||
@@ -94,12 +98,16 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio > 1) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if constexpr (DKQ <= 256) {
|
||||
if (use_gqa_opt && gqa_ratio > 1) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -3070,16 +3070,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
ggml_cuda_graph::node_properties prop = {};
|
||||
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
|
||||
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
|
||||
for (int j = 0; j < GGML_MAX_SRC; ++j) {
|
||||
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
|
||||
if (cgraph->nodes[i]->src[j]) {
|
||||
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data;
|
||||
memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j]));
|
||||
memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j]));
|
||||
}
|
||||
}
|
||||
|
||||
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
|
||||
if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
|
||||
graph->node_props[i] = prop;
|
||||
res = true;
|
||||
}
|
||||
graph->node_props[i] = prop;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
@@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
||||
switch (nc) {
|
||||
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
|
||||
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
|
||||
case 5: launch_kernel(std::integral_constant<int, 5>{}); break;
|
||||
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,59 +14,42 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define htp_act_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
#define htp_act_preamble2 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
#define htp_act_preamble \
|
||||
const struct htp_tensor * src0 = actx->octx->src[0]; \
|
||||
const struct htp_tensor * src1 = actx->octx->src[1]; \
|
||||
const struct htp_tensor * dst = actx->octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1 ? src1->ne[0] : 0; \
|
||||
const uint32_t ne11 = src1 ? src1->ne[1] : 0; \
|
||||
const uint32_t ne12 = src1 ? src1->ne[2] : 0; \
|
||||
const uint32_t ne13 = src1 ? src1->ne[3] : 0; \
|
||||
\
|
||||
const uint32_t nb10 = src1 ? src1->nb[0] : 0; \
|
||||
const uint32_t nb11 = src1 ? src1->nb[1] : 0; \
|
||||
const uint32_t nb12 = src1 ? src1->nb[2] : 0; \
|
||||
const uint32_t nb13 = src1 ? src1->nb[3] : 0; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct htp_act_context {
|
||||
@@ -97,10 +80,7 @@ struct htp_act_context {
|
||||
|
||||
static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
htp_act_preamble;
|
||||
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
@@ -207,10 +187,7 @@ static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
htp_act_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -332,9 +309,7 @@ static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, vo
|
||||
|
||||
static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
htp_act_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -433,9 +408,7 @@ static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
htp_act_preamble;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -533,10 +506,7 @@ static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
htp_act_preamble;
|
||||
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
@@ -652,9 +622,9 @@ static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {
|
||||
FARF(ERROR, "Non-contiguous tensors are not supported at this time \n");
|
||||
@@ -697,25 +667,20 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
size_t src0_row_size = src0->nb[1];
|
||||
size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
|
||||
size_t src1_row_size = src1 ? src1->nb[1] : src0->nb[1];
|
||||
size_t dst_row_size = dst->nb[1];
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
if (!src1_valid) {
|
||||
src1_row_size = src0_row_size;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
|
||||
size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned;
|
||||
size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if(vtcm_row_per_thread ==0){
|
||||
if (vtcm_row_per_thread == 0) {
|
||||
FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size_per_row * n_threads);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
@@ -733,7 +698,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
if (src1) {
|
||||
FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
@@ -773,9 +742,9 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
|
||||
// Pointers and GLU logic
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * data_src1 = (const uint8_t *) src1->data;
|
||||
const uint8_t * data_src1 = src1 ? (const uint8_t *) src1->data : NULL;
|
||||
|
||||
if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
|
||||
if (!src1 && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
|
||||
const int32_t swapped = octx->op_params[1];
|
||||
data_src1 = data_src0;
|
||||
actx.src1_row_size = actx.src0_row_size;
|
||||
@@ -799,7 +768,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int op_activations(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_activations_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "hex-dma.h"
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifndef MIN
|
||||
@@ -175,8 +175,8 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = actx->octx;
|
||||
|
||||
// Unpack context
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
// Scratchpad memory
|
||||
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
|
||||
@@ -249,16 +249,16 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
|
||||
int op_argsort(struct htp_ops_context * octx) {
|
||||
// Check supported types
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
|
||||
const uint32_t total_rows = octx->src[0]->ne[1] * octx->src[0]->ne[2] * octx->src[0]->ne[3];
|
||||
const uint32_t n_threads = MIN(total_rows, octx->n_threads);
|
||||
|
||||
// Allocate scratchpad
|
||||
// We need 1 row of float + 1 row of int32 per thread.
|
||||
uint32_t ne00 = octx->src0.ne[0];
|
||||
uint32_t ne00 = octx->src[0]->ne[0];
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
|
||||
size_t spad_per_thread = values_size + indices_size;
|
||||
@@ -278,9 +278,9 @@ int op_argsort(struct htp_ops_context * octx) {
|
||||
octx->src0_spad.size_per_thread = spad_per_thread;
|
||||
|
||||
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
|
||||
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
|
||||
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
|
||||
octx->src0.data, octx->dst.data);
|
||||
octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3],
|
||||
octx->dst->ne[0], octx->dst->ne[1], octx->dst->ne[2], octx->dst->ne[3],
|
||||
octx->src[0]->data, octx->dst->data);
|
||||
|
||||
struct htp_argsort_context actx;
|
||||
actx.octx = octx;
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifndef MIN
|
||||
@@ -43,10 +43,10 @@ struct htp_binary_context {
|
||||
bool split_at_ne02;
|
||||
};
|
||||
|
||||
#define htp_binary_preamble \
|
||||
const struct htp_tensor * src0 = &octx->src0; \
|
||||
const struct htp_tensor * src1 = &octx->src1; \
|
||||
struct htp_tensor * dst = &octx->dst; \
|
||||
#define htp_binary_preamble \
|
||||
const struct htp_tensor * src0 = octx->src[0]; \
|
||||
const struct htp_tensor * src1 = octx->src[1]; \
|
||||
const struct htp_tensor * dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
@@ -181,7 +181,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -274,7 +274,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -374,7 +374,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -455,7 +455,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
const uint32_t start_row = bctx->nrows_per_thread * ith;
|
||||
@@ -540,7 +540,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
htp_binary_preamble;
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
|
||||
const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
|
||||
const uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
@@ -629,10 +629,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_binary_context * bctx = (struct htp_binary_context *) data;
|
||||
struct htp_ops_context * octx = bctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * src2 = octx->src[2];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t ne00 = src0->ne[0];
|
||||
const uint32_t ne01 = src0->ne[1];
|
||||
@@ -723,15 +723,15 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
}
|
||||
|
||||
static int execute_op_binary(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
// Use packed row sizes for VTCM allocation
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
|
||||
const size_t src0_row_size = src0->ne[0] * elem_size;
|
||||
const size_t src1_row_size = src1->ne[0] * elem_size;
|
||||
@@ -799,9 +799,9 @@ static int execute_op_binary(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
return HTP_STATUS_OK;
|
||||
@@ -857,12 +857,12 @@ static int execute_op_binary(struct htp_ops_context * octx) {
|
||||
int op_binary(struct htp_ops_context * octx) {
|
||||
|
||||
// Does not support permutations of src1
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
if (src1->nb[1] < src1->nb[0]) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t src0_type = octx->src0.type;
|
||||
const uint32_t src0_type = octx->src[0]->type;
|
||||
if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
|
||||
return execute_op_binary(octx);
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
@@ -32,10 +32,10 @@ struct htp_copy_context {
|
||||
void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
|
||||
};
|
||||
|
||||
#define cpy_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0; \
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
#define cpy_preamble \
|
||||
const struct htp_tensor *src0 = octx->src[0]; \
|
||||
const struct htp_tensor *dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-dma.h"
|
||||
|
||||
#define htp_cumsum_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
#define htp_cumsum_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
@@ -206,8 +206,8 @@ static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
}
|
||||
|
||||
int op_cumsum_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
@@ -226,10 +226,12 @@ int op_cumsum_f32(struct htp_ops_context * octx) {
|
||||
|
||||
octx->src0_spad.size_per_thread = src_row_size_aligned * 2;
|
||||
octx->dst_spad.size_per_thread = dst_row_size_aligned * 2;
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
struct htp_cumsum_context cctx = {
|
||||
.octx = octx,
|
||||
@@ -251,8 +253,9 @@ int op_cumsum_f32(struct htp_ops_context * octx) {
|
||||
}
|
||||
|
||||
int op_cumsum(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
// Must be multiple of 32
|
||||
@@ -278,12 +278,12 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t *
|
||||
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_fa_context * factx = (struct htp_fa_context *) data;
|
||||
const struct htp_ops_context * octx = factx->octx;
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * q = octx->src[0];
|
||||
const struct htp_tensor * k = octx->src[1];
|
||||
const struct htp_tensor * v = octx->src[2];
|
||||
const struct htp_tensor * mask = octx->src[3];
|
||||
const struct htp_tensor * sinks = octx->src[4];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t neq0 = q->ne[0];
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
@@ -610,11 +610,11 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * q = octx->src[0];
|
||||
const struct htp_tensor * k = octx->src[1];
|
||||
const struct htp_tensor * v = octx->src[2];
|
||||
const struct htp_tensor * mask = octx->src[3];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
// Check support
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
|
||||
@@ -701,13 +701,11 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
||||
|
||||
// FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->src2_spad.src = NULL;
|
||||
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->src3_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
@@ -23,27 +23,33 @@ struct get_rows_context {
|
||||
};
|
||||
|
||||
#define get_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
const uint32_t ne02 = octx->src0.ne[2]; \
|
||||
const uint32_t ne03 = octx->src0.ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src1.ne[0]; \
|
||||
const uint32_t ne11 = octx->src1.ne[1]; \
|
||||
const uint32_t ne12 = octx->src1.ne[2]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src0.nb[1]; \
|
||||
const uint32_t nb02 = octx->src0.nb[2]; \
|
||||
const uint32_t nb03 = octx->src0.nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src1.nb[0]; \
|
||||
const uint32_t nb11 = octx->src1.nb[1]; \
|
||||
const uint32_t nb12 = octx->src1.nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst.nb[1]; \
|
||||
const uint32_t nb2 = octx->dst.nb[2]; \
|
||||
const uint32_t nb3 = octx->dst.nb[3]; \
|
||||
\
|
||||
const uint32_t ne00 = octx->src[0]->ne[0]; \
|
||||
const uint32_t ne01 = octx->src[0]->ne[1]; \
|
||||
const uint32_t ne02 = octx->src[0]->ne[2]; \
|
||||
const uint32_t ne03 = octx->src[0]->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src[1]->ne[0]; \
|
||||
const uint32_t ne11 = octx->src[1]->ne[1]; \
|
||||
const uint32_t ne12 = octx->src[1]->ne[2]; \
|
||||
const uint32_t ne13 = octx->src[1]->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = octx->dst->ne[0]; \
|
||||
const uint32_t ne1 = octx->dst->ne[1]; \
|
||||
const uint32_t ne2 = octx->dst->ne[2]; \
|
||||
const uint32_t ne3 = octx->dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src[0]->nb[1]; \
|
||||
const uint32_t nb02 = octx->src[0]->nb[2]; \
|
||||
const uint32_t nb03 = octx->src[0]->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src[1]->nb[0]; \
|
||||
const uint32_t nb11 = octx->src[1]->nb[1]; \
|
||||
const uint32_t nb12 = octx->src[1]->nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst->nb[1]; \
|
||||
const uint32_t nb2 = octx->dst->nb[2]; \
|
||||
const uint32_t nb3 = octx->dst->nb[3]; \
|
||||
\
|
||||
const uint32_t nr = ne10 * ne11 * ne12;
|
||||
|
||||
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
@@ -51,12 +57,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
struct htp_ops_context * octx = grctx->octx;
|
||||
get_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
// parallelize by src1 elements (which correspond to dst rows)
|
||||
const uint32_t dr = grctx->src1_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
|
||||
@@ -64,7 +72,7 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
|
||||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
|
||||
@@ -73,10 +81,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
@@ -84,15 +96,15 @@ int op_get_rows(struct htp_ops_context * octx) {
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->dst.type != HTP_TYPE_F32) {
|
||||
if (octx->dst->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
|
||||
if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
@@ -102,8 +114,8 @@ int op_get_rows(struct htp_ops_context * octx) {
|
||||
|
||||
struct get_rows_context grctx;
|
||||
grctx.octx = octx;
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]);
|
||||
|
||||
grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <qurt_memory.h>
|
||||
|
||||
#include "hexagon_types.h"
|
||||
#include "hexagon_protos.h"
|
||||
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-dump.h"
|
||||
@@ -68,4 +70,23 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride,
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
}
|
||||
|
||||
#define HEX_L2_LINE_SIZE 64
|
||||
#define HEX_L2_FLUSH_SIZE (128 * 1024)
|
||||
|
||||
static inline void hex_l2flush(void * addr, size_t size)
|
||||
{
|
||||
if (size > HEX_L2_FLUSH_SIZE) {
|
||||
qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE);
|
||||
} else {
|
||||
const uint32_t s = (uint32_t) addr;
|
||||
const uint32_t e = s + size;
|
||||
for (uint32_t i = s; i < e; i += HEX_L2_LINE_SIZE * 4) {
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 0);
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 1);
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 2);
|
||||
Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HEX_UTILS_H */
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
#include "hvx-dump.h"
|
||||
#include "worker-pool.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#include "hmx-utils.h"
|
||||
#include "hmx-ops.h"
|
||||
@@ -821,7 +821,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
|
||||
// and each q_head is computed individually to avoid tile-major packing
|
||||
// issues. m_chunk_n_rows is always a multiple of 32 (from
|
||||
// hmx_compute_chunks), so per-head tile arrays don't overlap.
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = params->k * sizeof(__fp16);
|
||||
|
||||
// When the activation has a large stride (e.g. permuted Q tensor with
|
||||
@@ -998,7 +998,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
|
||||
}
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
|
||||
// DMA-based activation gather for strided tensors (see batched path comment).
|
||||
@@ -1182,7 +1182,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type);
|
||||
|
||||
// --- Dynamic VTCM layout ---
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
const size_t vec_dot_size = k * sizeof(__fp16);
|
||||
const bool use_pipeline = (m >= 128) && (k <= n);
|
||||
|
||||
@@ -1273,9 +1273,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
|
||||
void *buf_curr = vtcm_scratch0;
|
||||
void *buf_next = vtcm_scratch1;
|
||||
|
||||
// issue async DDR data transfer for the first weight chunk
|
||||
// NOTE: use 2D DMA (n_cols rows x row_stride bytes) instead of 1D
|
||||
// because UDMA roiwidth is 16-bit and total size can exceed 65535.
|
||||
{
|
||||
const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
|
||||
@@ -1533,20 +1530,15 @@ void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, co
|
||||
worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads);
|
||||
}
|
||||
|
||||
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
|
||||
int k, int n, int weight_type) {
|
||||
// Runtime check -- k >= 16384 exceeds 2D DMA limit
|
||||
if (k >= 16384) {
|
||||
FARF(HIGH, "%s: k=%d exceeds 2D DMA limit", __func__, k);
|
||||
return -1;
|
||||
}
|
||||
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w,
|
||||
int m, int k, int n, int weight_type) {
|
||||
// assume k % 32 == 0 && n % 32 == 0
|
||||
const size_t row_stride = get_x4x2_row_stride(weight_type, k);
|
||||
if (row_stride == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
const size_t vtcm_budget = ctx->vtcm_scratch_size;
|
||||
const size_t vtcm_budget = ctx->vtcm_size;
|
||||
|
||||
const size_t M_BLOCK_SIZE = 512;
|
||||
const size_t N_BLOCK_SIZE = 512;
|
||||
@@ -1576,8 +1568,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
__fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
|
||||
assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);
|
||||
|
||||
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu",
|
||||
__func__, m, k, n, weight_type,
|
||||
FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type,
|
||||
(size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
// initialize eye tile (32x32 identity matrix)
|
||||
|
||||
@@ -7,16 +7,12 @@
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifndef restrict
|
||||
# define restrict __restrict
|
||||
#endif
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct htp_context; // forward declaration
|
||||
|
||||
typedef struct {
|
||||
float *dst;
|
||||
const float *activation;
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define HTP_CTX_H
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "htp-ops.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
@@ -10,38 +11,85 @@
|
||||
#include <stdint.h>
|
||||
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
#define HTP_MAX_MMAPS 16
|
||||
|
||||
// Memory mapping
|
||||
struct htp_mmap {
|
||||
uint64_t size;
|
||||
uint64_t base;
|
||||
uint32_t fd;
|
||||
uint32_t pinned;
|
||||
};
|
||||
|
||||
// Scratchpad state
|
||||
struct htp_spad {
|
||||
const struct htp_tensor * src; // original src of the data (for reuse)
|
||||
uint8_t * data; // pointer to an area in vtcm
|
||||
uint32_t stride; // stride used inside this spad
|
||||
uint32_t size; // total size
|
||||
uint32_t size_per_thread; // size per thread
|
||||
};
|
||||
|
||||
// Context while processing an Op
|
||||
// TODO: fold this into the main context
|
||||
struct htp_ops_context {
|
||||
struct htp_context * ctx;
|
||||
|
||||
enum htp_op_code op; // FIXME: rename to opcode
|
||||
int32_t op_params[HTP_OP_MAX_PARAMS];
|
||||
|
||||
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
|
||||
const struct htp_tensor * dst;
|
||||
|
||||
// TODO convert these to an array
|
||||
struct htp_spad src0_spad;
|
||||
struct htp_spad src1_spad;
|
||||
struct htp_spad src2_spad;
|
||||
struct htp_spad src3_spad;
|
||||
struct htp_spad dst_spad;
|
||||
|
||||
uint32_t n_threads;
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
// Main context for htp DSP backend
|
||||
struct htp_context {
|
||||
dspqueue_t queue;
|
||||
dma_queue * dma[HTP_MAX_NTHREADS];
|
||||
worker_pool_context_t worker_pool;
|
||||
uint32_t n_threads;
|
||||
dspqueue_t queue;
|
||||
dma_queue * dma[HTP_MAX_NTHREADS];
|
||||
struct htp_mmap mmap[HTP_MAX_MMAPS];
|
||||
worker_pool_context_t worker_pool;
|
||||
uint32_t n_threads;
|
||||
|
||||
int thread_id;
|
||||
int thread_prio;
|
||||
int thread_id;
|
||||
int thread_prio;
|
||||
|
||||
uint8_t * vtcm_base;
|
||||
size_t vtcm_size;
|
||||
uint32_t vtcm_rctx;
|
||||
int hmx_enabled;
|
||||
|
||||
atomic_bool vtcm_valid;
|
||||
atomic_bool vtcm_inuse;
|
||||
atomic_bool vtcm_needs_release;
|
||||
uint8_t * vtcm_base;
|
||||
size_t vtcm_size;
|
||||
uint32_t vtcm_rctx;
|
||||
atomic_bool vtcm_valid;
|
||||
atomic_bool vtcm_needs_release;
|
||||
|
||||
uint32_t opmask;
|
||||
|
||||
// Cached src1 spad position from the last quantize pass.
|
||||
// When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM
|
||||
// at this address; the matmul must read from here instead of recomputing
|
||||
// the offset (which depends on the current op's src0 size).
|
||||
uint8_t * prev_src1_spad;
|
||||
|
||||
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
|
||||
#ifdef HTP_HAS_HMX
|
||||
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
|
||||
size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation)
|
||||
#endif
|
||||
struct htp_ops_context octx;
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
int op_rope(struct htp_ops_context * octx);
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_repeat(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
int op_cumsum(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
#ifndef HTP_MSG_H
|
||||
#define HTP_MSG_H
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
// ggml-common.h must be included prio to this header
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
// Used for debugging and profiling.
|
||||
enum {
|
||||
HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP)
|
||||
HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize
|
||||
HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute
|
||||
};
|
||||
|
||||
// Op flags
|
||||
enum {
|
||||
HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors)
|
||||
HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling)
|
||||
HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification
|
||||
};
|
||||
|
||||
enum htp_status {
|
||||
HTP_STATUS_OK = 1,
|
||||
HTP_STATUS_INTERNAL_ERR = 2,
|
||||
HTP_STATUS_NO_SUPPORT = 3,
|
||||
HTP_STATUS_INVAL_PARAMS = 4,
|
||||
HTP_STATUS_VTCM_TOO_SMALL = 5,
|
||||
};
|
||||
|
||||
// The values must match the ggml_type.
|
||||
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||
enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_IQ4_NL = 20,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
// Do not reorder first 4 (used as an index)
|
||||
enum htp_op {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_UNARY_SIGMOID,
|
||||
HTP_OP_UNARY_EXP,
|
||||
HTP_OP_UNARY_NEG,
|
||||
HTP_OP_UNARY_SOFTPLUS,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
HTP_OP_SOFTMAX,
|
||||
HTP_OP_ADD_ID,
|
||||
HTP_OP_ROPE,
|
||||
HTP_OP_FLASH_ATTN_EXT,
|
||||
HTP_OP_SET_ROWS,
|
||||
HTP_OP_GET_ROWS,
|
||||
HTP_OP_SCALE,
|
||||
HTP_OP_CPY,
|
||||
HTP_OP_ARGSORT,
|
||||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
HTP_OP_REPEAT,
|
||||
HTP_OP_CUMSUM,
|
||||
INVALID
|
||||
};
|
||||
|
||||
static inline size_t htp_t_block_size(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 1;
|
||||
case HTP_TYPE_F16:
|
||||
return 1;
|
||||
case HTP_TYPE_Q4_0:
|
||||
return QK4_0;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return QK8_0;
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return QK4_NL;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return QK_MXFP4;
|
||||
default:
|
||||
assert(0 && "unsupported HTP data type");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static inline size_t htp_type_nbytes(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 4;
|
||||
case HTP_TYPE_F16:
|
||||
return 2;
|
||||
case HTP_TYPE_Q4_0:
|
||||
return sizeof(block_q4_0);
|
||||
case HTP_TYPE_Q8_0:
|
||||
return sizeof(block_q8_0);
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return sizeof(block_iq4_nl);
|
||||
case HTP_TYPE_MXFP4:
|
||||
return sizeof(block_mxfp4);
|
||||
default:
|
||||
assert(0 && "unsupported HTP data type");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
|
||||
#define HTP_MAX_DIMS 4
|
||||
|
||||
struct htp_tensor {
|
||||
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
|
||||
uint32_t type; // Data type
|
||||
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
|
||||
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
|
||||
};
|
||||
|
||||
#define HTP_MAX_OP_PARAMS 64
|
||||
|
||||
struct htp_general_req {
|
||||
uint32_t op; // GGML/HTP Op
|
||||
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
// Params for the op, e.g. epsilon of RMS norm
|
||||
uint32_t flags; // Request flags
|
||||
|
||||
struct htp_tensor src0; // Input0 tensor
|
||||
struct htp_tensor src1; // Input1 tensor
|
||||
struct htp_tensor src2; // Input2 tensor
|
||||
struct htp_tensor src3; // Input3 tensor
|
||||
struct htp_tensor src4; // Input4 tensor
|
||||
struct htp_tensor dst; // Output tensor
|
||||
|
||||
// should be multiple of 64 bytes (cacheline)
|
||||
};
|
||||
|
||||
struct htp_general_rsp {
|
||||
uint32_t op; // GGML/HTP Op
|
||||
uint32_t status; // HTP_STATUS_...
|
||||
uint32_t prof_usecs; // Number of usec per request
|
||||
uint32_t prof_cycles; // Number of cycles per request
|
||||
uint32_t prof_pkts; // Number of instruction packets per request
|
||||
uint8_t unused[44]; // Pad to 64 bytes
|
||||
};
|
||||
|
||||
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
|
||||
#define HTP_MAX_PACKET_BUFFERS 8
|
||||
|
||||
#endif /* HTP_MSG_H */
|
||||
@@ -1,65 +1,154 @@
|
||||
#ifndef HTP_OPS_H
|
||||
#define HTP_OPS_H
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <hex-fastdiv.h>
|
||||
// ggml-common.h must be included prio to this header
|
||||
|
||||
// ggml-common.h must be included prior to this header
|
||||
|
||||
struct htp_spad {
|
||||
uint8_t * data;
|
||||
size_t stride;
|
||||
size_t size;
|
||||
size_t size_per_thread;
|
||||
enum htp_status {
|
||||
HTP_STATUS_OK = 1,
|
||||
HTP_STATUS_INTERNAL_ERR = 2,
|
||||
HTP_STATUS_NO_SUPPORT = 3,
|
||||
HTP_STATUS_INVAL_PARAMS = 4,
|
||||
HTP_STATUS_VTCM_TOO_SMALL = 5,
|
||||
};
|
||||
|
||||
struct htp_ops_context {
|
||||
struct htp_context * ctx;
|
||||
// First set of values must match the ggml_type.
|
||||
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||
enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_IQ4_NL = 20,
|
||||
HTP_TYPE_I32 = 26,
|
||||
HTP_TYPE_I64 = 27,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
|
||||
enum htp_op op;
|
||||
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
// types used internally for repack, dyn.quant, etc
|
||||
HTP_TYPE_Q4_0x4x2 = 200,
|
||||
HTP_TYPE_Q8_0x4x2,
|
||||
HTP_TYPE_MXFP4x4x2,
|
||||
|
||||
struct htp_tensor src0;
|
||||
struct htp_tensor src1;
|
||||
struct htp_tensor src2;
|
||||
struct htp_tensor src3;
|
||||
struct htp_tensor src4;
|
||||
struct htp_tensor dst;
|
||||
|
||||
struct htp_spad src0_spad;
|
||||
struct htp_spad src1_spad;
|
||||
struct htp_spad src2_spad;
|
||||
struct htp_spad src3_spad;
|
||||
struct htp_spad dst_spad;
|
||||
|
||||
worker_pool_context_t * wpool; // worker pool
|
||||
uint32_t n_threads; // num threads
|
||||
|
||||
uint32_t flags;
|
||||
HTP_TYPE_INVALID
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
int op_rope(struct htp_ops_context * octx);
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_repeat(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
int op_ssm_conv(struct htp_ops_context * octx);
|
||||
int op_cumsum(struct htp_ops_context * octx);
|
||||
// Constats for internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
// Used for debugging and profiling.
|
||||
enum htp_op_mask {
|
||||
HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP)
|
||||
HTP_OPMASK_COMPUTE = (1 << 1), // Enable Compute
|
||||
};
|
||||
|
||||
// Do not reorder first 4 (used as an index)
|
||||
enum htp_op_code {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_UNARY_SIGMOID,
|
||||
HTP_OP_UNARY_EXP,
|
||||
HTP_OP_UNARY_NEG,
|
||||
HTP_OP_UNARY_SOFTPLUS,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
HTP_OP_SOFTMAX,
|
||||
HTP_OP_ADD_ID,
|
||||
HTP_OP_ROPE,
|
||||
HTP_OP_FLASH_ATTN_EXT,
|
||||
HTP_OP_SET_ROWS,
|
||||
HTP_OP_GET_ROWS,
|
||||
HTP_OP_SCALE,
|
||||
HTP_OP_CPY,
|
||||
HTP_OP_ARGSORT,
|
||||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
HTP_OP_SSM_CONV,
|
||||
HTP_OP_REPEAT,
|
||||
HTP_OP_CUMSUM,
|
||||
|
||||
HTP_OP_INVALID
|
||||
};
|
||||
|
||||
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
|
||||
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
|
||||
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
|
||||
|
||||
#define HTP_OP_MAX_BUFS 8
|
||||
#define HTP_OP_MAX_REQS 256
|
||||
#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS)
|
||||
#define HTP_OP_MAX_VMEM (3221225472u)
|
||||
|
||||
enum htp_tensor_flags {
|
||||
HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights)
|
||||
HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU)
|
||||
};
|
||||
|
||||
// Tensor descriptor
|
||||
struct htp_tensor {
|
||||
uint32_t data; // Buffer offset in the messages, and data pointer on the NPU
|
||||
uint32_t size; // Data size in bytes
|
||||
uint32_t flags; // Buffer / tensor flags
|
||||
uint16_t type; // Data type
|
||||
uint16_t bi; // Buffer index
|
||||
uint32_t ne[HTP_OP_MAX_DIMS]; // Number of elements
|
||||
uint32_t nb[HTP_OP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
|
||||
};
|
||||
|
||||
// Buffer descriptor
|
||||
struct htp_buf_desc {
|
||||
uint64_t base; // base address
|
||||
uint64_t size; // total size
|
||||
uint32_t flags; // buffer flags (unused)
|
||||
uint32_t fd; // file descriptor
|
||||
};
|
||||
|
||||
enum htp_op_flags {
|
||||
HTP_OPFLAGS_SKIP_COMPUTE = (1U << 0), // Skip actual computation (used for profiling)
|
||||
};
|
||||
|
||||
// Op descriptor
|
||||
struct htp_op_desc {
|
||||
uint32_t opcode; // GGML/HTP Op
|
||||
uint32_t flags; // Op flags
|
||||
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
|
||||
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
|
||||
uint16_t dst; // Output tensor index
|
||||
|
||||
// the rest is filled in-place by the NPU
|
||||
uint32_t prof_usecs; // Number of usec per request
|
||||
uint32_t prof_cycles; // Number of cycles per request
|
||||
uint32_t prof_pkts; // Number of instruction packets per request
|
||||
uint32_t unused;
|
||||
};
|
||||
|
||||
struct htp_opbatch_req {
|
||||
uint32_t n_bufs; // Number of buffers
|
||||
uint32_t n_tensors; // Number of tensors
|
||||
uint32_t n_ops; // Number of ops
|
||||
uint32_t flags; // unused
|
||||
// struct htp_buf_desc bufs[]; -- dspqueue buf 0
|
||||
// struct htp_tensor tensors[]; -- dspqueue buf 0
|
||||
// struct htp_op_desc ops[]; -- dspqueue buf 0
|
||||
};
|
||||
|
||||
struct htp_opbatch_rsp {
|
||||
uint32_t status; // HTP_STATUS_...
|
||||
// struct htp_op_req ops[]; -- dspqueue buf 0
|
||||
};
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx);
|
||||
AEEResult stop();
|
||||
AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned);
|
||||
AEEResult munmap(in uint32 fd);
|
||||
AEEResult enable_etm();
|
||||
AEEResult disable_etm();
|
||||
};
|
||||
|
||||
+356
-1142
File diff suppressed because it is too large
Load Diff
@@ -16,8 +16,9 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
#define MM_SPAD_SRC0_NROWS 16
|
||||
#define MM_SPAD_SRC1_NROWS 16
|
||||
@@ -1897,11 +1898,11 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
#define htp_matmul_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict src1 = &octx->src1; \
|
||||
struct htp_tensor * restrict src2 = &octx->src2; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
#define htp_matmul_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict src2 = octx->src[2]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
@@ -2223,8 +2224,8 @@ struct mmid_row_mapping {
|
||||
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
|
||||
struct htp_tensor * restrict ids = &octx->src2;
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -2342,8 +2343,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
|
||||
struct htp_tensor * restrict ids = &octx->src2;
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -2612,7 +2613,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
|
||||
const struct htp_tensor * src = &octx->src1;
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
struct htp_spad * spad = &octx->src0_spad;
|
||||
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
||||
@@ -2659,7 +2660,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
|
||||
const struct htp_tensor * src = &octx->src1;
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
||||
uint32_t dst_stride = octx->src1_spad.stride;
|
||||
@@ -2701,7 +2702,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
|
||||
const struct htp_tensor * src = &octx->src1;
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
|
||||
uint32_t dst_stride = octx->src1_spad.stride;
|
||||
@@ -2800,7 +2801,7 @@ static void htp_mminit_spad(struct htp_ops_context * octx,
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
||||
}
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx) {
|
||||
static int op_matmul_hvx(struct htp_ops_context * octx) {
|
||||
htp_matmul_tensors_preamble;
|
||||
|
||||
struct htp_matmul_context mmctx_struct = {0};
|
||||
@@ -2824,7 +2825,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
worker_callback_t quant_job_func;
|
||||
worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
|
||||
|
||||
bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
|
||||
bool need_quant = true;
|
||||
|
||||
if (src0->type == HTP_TYPE_F16) {
|
||||
// Try optimized f16-f16 path first (src1 in VTCM)
|
||||
@@ -2838,7 +2839,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
// Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
|
||||
// It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
|
||||
const bool is_batched = (ne02 > 1) || (ne03 > 1);
|
||||
const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
|
||||
const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]);
|
||||
|
||||
if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
|
||||
// Optimized path
|
||||
@@ -2915,34 +2916,172 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
// Place src1 spad first. We use it for dyn.quant and may reuse between ops
|
||||
octx->src1_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
octx->src0_spad.stride = src0_row_size_padded;
|
||||
octx->src1_spad.stride = src1_row_size;
|
||||
|
||||
if (need_quant) {
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)
|
||||
return HTP_STATUS_OK;
|
||||
|
||||
if (need_quant && !octx->src1_spad.src) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
// Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
// SKIP_QUANTIZE: Q8 data lives at the address written by the previous
|
||||
// quantize pass. The current op may have a different src0 size (e.g.
|
||||
// IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong.
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
octx->src1_spad.src = src1;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
|
||||
}
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx) {
|
||||
htp_matmul_tensors_preamble;
|
||||
|
||||
#ifndef HTP_HAS_HMX
|
||||
return op_matmul_hvx(octx);
|
||||
#else
|
||||
if (!octx->ctx->hmx_enabled) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// HMX weight tile requires N to be 32-aligned.
|
||||
if (src0->ne[1] % 32 != 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
|
||||
// Other types fall back to HVX.
|
||||
uint32_t wtype = src0->type;
|
||||
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// Quantised HMX path requires K aligned to 256 (x4x2 super-block).
|
||||
// F16 HMX path requires K aligned to 32 (tile width).
|
||||
if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1);
|
||||
|
||||
// Quantised HMX kernels only handle flat 2D matmul (host already rejects
|
||||
// batched quantised, but guard here too). F16 batched matmul is handled
|
||||
// by the dedicated wrapper in hmx-matmul-ops.c.
|
||||
if (is_batched && src0->type != HTP_TYPE_F16) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// HMX assumes contiguous row-major layout. Fall back for permuted
|
||||
// tensors where strides are non-monotonic (e.g. transposed KV cache).
|
||||
if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// M alignment: when M > 32 but not 32-aligned, we split into
|
||||
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
|
||||
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
|
||||
const int m_total = (int) src1->ne[1];
|
||||
const int m_tail = m_total % 32;
|
||||
const int m_hmx = m_total - m_tail;
|
||||
|
||||
if (m_hmx == 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// Always re-quantize src1 since HMX kernel overwrites vtcm/spad,
|
||||
// so any previously cached quantized data is invalid.
|
||||
octx->src1_spad.src = NULL;
|
||||
|
||||
int k = (int) src0->ne[0]; // inner dimension
|
||||
int n = (int) src0->ne[1]; // weight columns
|
||||
|
||||
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
|
||||
int ret = -1;
|
||||
|
||||
// Row strides in elements. For compact tensors these equal k; for
|
||||
// permuted attention views they can be larger, so pass the real stride.
|
||||
const int act_stride = (int)(src1->nb[1] / sizeof(float));
|
||||
const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16));
|
||||
|
||||
if (src0->type == HTP_TYPE_F16) {
|
||||
if (is_batched) {
|
||||
hmx_matmul_w16a32_batched_params_t batch_params = {
|
||||
.dst = (float *) dst->data,
|
||||
.activation = (float *) src1->data,
|
||||
.permuted_weight = (const __fp16 *) src0->data,
|
||||
.m = m_hmx,
|
||||
.k = k,
|
||||
.n = n,
|
||||
.act_stride = act_stride,
|
||||
.weight_stride = wgt_stride,
|
||||
.dst_stride = (int) (dst->nb[1] / sizeof(float)),
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.ne12 = ne12,
|
||||
.ne13 = ne13,
|
||||
.src0_nb2 = src0->nb[2],
|
||||
.src0_nb3 = src0->nb[3],
|
||||
.src1_nb2 = src1->nb[2],
|
||||
.src1_nb3 = src1->nb[3],
|
||||
.dst_nb2 = dst->nb[2],
|
||||
.dst_nb3 = dst->nb[3],
|
||||
};
|
||||
ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params);
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
|
||||
m_hmx, k, n, act_stride, wgt_stride);
|
||||
}
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
m_hmx, k, n, (int) src0->type);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
|
||||
return op_matmul(octx);
|
||||
}
|
||||
|
||||
// --- Phase 2: HVX on the remaining m_tail rows ---
|
||||
if (m_tail > 0) {
|
||||
// copy of src1 and dst
|
||||
struct htp_tensor src1_tail = *src1;
|
||||
struct htp_tensor dst_tail = *dst;
|
||||
|
||||
src1_tail.ne[1] = m_tail; // only tail rows
|
||||
dst_tail.ne[1] = m_tail; // only tail rows
|
||||
|
||||
// Offset activation and dst pointers past the HMX-processed rows.
|
||||
// Use nb[1] (row stride in bytes) to compute the byte offset.
|
||||
src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
|
||||
dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
|
||||
|
||||
octx->src[1] = &src1_tail;
|
||||
octx->dst = &dst_tail;
|
||||
|
||||
FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
return 0;
|
||||
#endif // HTP_HAS_HMX
|
||||
}
|
||||
|
||||
int op_matmul_id(struct htp_ops_context * octx) {
|
||||
htp_matmul_tensors_preamble;
|
||||
|
||||
@@ -2950,7 +3089,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
struct htp_matmul_context * mmctx = &mmctx_struct;
|
||||
mmctx->octx = octx;
|
||||
|
||||
struct htp_tensor * restrict ids = &octx->src2;
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
@@ -3003,11 +3142,17 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
// Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops.
|
||||
octx->src1_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
||||
|
||||
octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL;
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src2_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
octx->src0_spad.stride = src0_row_size_padded;
|
||||
octx->src1_spad.stride = src1_row_size;
|
||||
|
||||
@@ -3031,20 +3176,18 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Setup worker pool callbacks
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)
|
||||
return HTP_STATUS_OK;
|
||||
|
||||
if (octx->src1_spad.src != src1) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
octx->src1_spad.src = src1;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
|
||||
}
|
||||
const uint32_t n_matmul_jobs = octx->n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_repeat_context {
|
||||
@@ -32,8 +32,8 @@ struct htp_repeat_context {
|
||||
static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const uint32_t ne00 = src->ne[0];
|
||||
const uint32_t ne01 = src->ne[1];
|
||||
@@ -98,8 +98,8 @@ static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * dat
|
||||
}
|
||||
|
||||
int op_repeat(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
// Validate that dst dims are multiples of src dims
|
||||
if (dst->ne[0] % src0->ne[0] != 0 ||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h
|
||||
@@ -253,10 +253,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_rope_context * rctx = (struct htp_rope_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * src2 = octx->src[2];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
@@ -284,7 +284,7 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
|
||||
const float * freq_factors = src2 ? (const float *) src2->data : NULL;
|
||||
|
||||
uint32_t ir = 0;
|
||||
uint32_t prev_i2 = (uint32_t) -1;
|
||||
@@ -384,10 +384,10 @@ done:
|
||||
static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * src2 = octx->src[2];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const char * op_type = "rope-f32";
|
||||
|
||||
@@ -424,19 +424,16 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
// Assign sizes
|
||||
octx->src0_spad.size_per_thread = src0_spad_per_thread;
|
||||
octx->dst_spad.size_per_thread = dst_spad_per_thread;
|
||||
octx->src0_spad.size = n_threads * src0_spad_per_thread;
|
||||
octx->dst_spad.size = n_threads * dst_spad_per_thread;
|
||||
octx->src1_spad.size = 0;
|
||||
|
||||
// Assign pointers
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = NULL; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
// Fill context
|
||||
struct htp_rope_context rctx;
|
||||
memset(&rctx, 0, sizeof(struct htp_rope_context));
|
||||
|
||||
@@ -483,7 +480,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
int op_rope(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_rope_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -14,33 +14,37 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define set_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
const uint32_t ne02 = octx->src0.ne[2]; \
|
||||
const uint32_t ne03 = octx->src0.ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src1.ne[0]; \
|
||||
const uint32_t ne11 = octx->src1.ne[1]; \
|
||||
const uint32_t ne12 = octx->src1.ne[2]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src0.nb[1]; \
|
||||
const uint32_t nb02 = octx->src0.nb[2]; \
|
||||
const uint32_t nb03 = octx->src0.nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src1.nb[0]; \
|
||||
const uint32_t nb11 = octx->src1.nb[1]; \
|
||||
const uint32_t nb12 = octx->src1.nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst.nb[1]; \
|
||||
const uint32_t nb2 = octx->dst.nb[2]; \
|
||||
const uint32_t nb3 = octx->dst.nb[3]; \
|
||||
\
|
||||
const uint32_t ne1 = octx->dst.ne[1]; \
|
||||
\
|
||||
#define set_rows_preamble \
|
||||
const uint32_t ne00 = octx->src[0]->ne[0]; \
|
||||
const uint32_t ne01 = octx->src[0]->ne[1]; \
|
||||
const uint32_t ne02 = octx->src[0]->ne[2]; \
|
||||
const uint32_t ne03 = octx->src[0]->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = octx->src[1]->ne[0]; \
|
||||
const uint32_t ne11 = octx->src[1]->ne[1]; \
|
||||
const uint32_t ne12 = octx->src[1]->ne[2]; \
|
||||
const uint32_t ne13 = octx->src[1]->ne[3]; \
|
||||
\
|
||||
const uint32_t nb01 = octx->src[0]->nb[1]; \
|
||||
const uint32_t nb02 = octx->src[0]->nb[2]; \
|
||||
const uint32_t nb03 = octx->src[0]->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = octx->src[1]->nb[0]; \
|
||||
const uint32_t nb11 = octx->src[1]->nb[1]; \
|
||||
const uint32_t nb12 = octx->src[1]->nb[2]; \
|
||||
\
|
||||
const uint32_t nb1 = octx->dst->nb[1]; \
|
||||
const uint32_t nb2 = octx->dst->nb[2]; \
|
||||
const uint32_t nb3 = octx->dst->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = octx->dst->ne[0]; \
|
||||
const uint32_t ne1 = octx->dst->ne[1]; \
|
||||
const uint32_t ne2 = octx->dst->ne[2]; \
|
||||
const uint32_t ne3 = octx->dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
struct htp_set_rows_context {
|
||||
@@ -56,12 +60,14 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
@@ -70,7 +76,7 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
@@ -78,14 +84,18 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
|
||||
continue;
|
||||
}
|
||||
|
||||
const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
const uintptr_t src0_ptr = octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
const uintptr_t dst_ptr = octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
// copy row
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "set-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
@@ -94,12 +104,14 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
@@ -108,7 +120,7 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
||||
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
|
||||
if (i1 >= ne1) {
|
||||
@@ -116,13 +128,17 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
const uint8_t* src0_ptr = (const uint8_t *) octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03;
|
||||
uint8_t* dst_ptr = (uint8_t *) octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "set-rows-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
|
||||
}
|
||||
|
||||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
@@ -130,15 +146,15 @@ int op_set_rows(struct htp_ops_context * octx) {
|
||||
|
||||
const uint32_t n_threads = MIN(nr, octx->n_threads);
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
|
||||
if (octx->dst->type != HTP_TYPE_F32 && octx->dst->type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
|
||||
if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
@@ -153,7 +169,7 @@ int op_set_rows(struct htp_ops_context * octx) {
|
||||
|
||||
srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
switch(octx->dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
|
||||
break;
|
||||
|
||||
@@ -15,68 +15,89 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define htp_softmax_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \
|
||||
const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \
|
||||
\
|
||||
const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \
|
||||
const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \
|
||||
const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \
|
||||
const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
#define htp_softmax_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1 ? src1->ne[0] : 1; \
|
||||
const uint32_t ne11 = src1 ? src1->ne[1] : 1; \
|
||||
const uint32_t ne12 = src1 ? src1->ne[2] : 1; \
|
||||
const uint32_t ne13 = src1 ? src1->ne[3] : 1; \
|
||||
\
|
||||
const uint32_t nb10 = src1 ? src1->nb[0] : 1; \
|
||||
const uint32_t nb11 = src1 ? src1->nb[1] : 1; \
|
||||
const uint32_t nb12 = src1 ? src1->nb[2] : 1; \
|
||||
const uint32_t nb13 = src1 ? src1->nb[3] : 1; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct htp_softmax_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
bool use_f16;
|
||||
bool use_src1;
|
||||
|
||||
uint32_t n_head;
|
||||
uint32_t n_head_log2;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
struct fastdiv_values fastdiv_ne01;
|
||||
struct fastdiv_values fastdiv_ne02;
|
||||
struct fastdiv_values fastdiv_ne12; // For mask broadcasting
|
||||
struct fastdiv_values fastdiv_ne13; // For mask broadcasting
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
};
|
||||
|
||||
static void apply_mask(float * restrict wp0,
|
||||
const float * restrict mp_f32,
|
||||
const __fp16 * restrict mp_f16,
|
||||
uint32_t ne00,
|
||||
float slope,
|
||||
bool use_f16) {
|
||||
if (!mp_f32) {
|
||||
return;
|
||||
}
|
||||
if (use_f16) {
|
||||
for (uint32_t i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (uint32_t i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
|
||||
memset(smctx, 0, sizeof(struct htp_softmax_context));
|
||||
|
||||
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
|
||||
smctx->n_head = src0->ne[2];
|
||||
@@ -85,8 +106,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_
|
||||
smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
|
||||
smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
|
||||
|
||||
smctx->use_src1 = (src1->ne[0] != 0);
|
||||
smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
smctx->use_src1 = (src1 != 0);
|
||||
smctx->use_f16 = (src1 != 0) && (src1->type == HTP_TYPE_F16);
|
||||
|
||||
smctx->octx = octx;
|
||||
|
||||
@@ -97,8 +118,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_
|
||||
if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
|
||||
if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
|
||||
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
|
||||
const uint32_t ne12 = src1 ? src1->ne[2] : 1;
|
||||
const uint32_t ne13 = src1 ? src1->ne[3] : 1;
|
||||
|
||||
if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
|
||||
if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
|
||||
@@ -139,10 +160,7 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict pad,
|
||||
const int num_elems) {
|
||||
static void hvx_fast_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, const int num_elems) {
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_pad = (HVX_Vector *) pad;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
@@ -188,27 +206,20 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static float hvx_softmax_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const int num_elems,
|
||||
const float max) {
|
||||
static float hvx_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict spad, const int num_elems, const float max) {
|
||||
hvx_sub_scalar_f32(spad, src, max, num_elems);
|
||||
|
||||
hvx_exp_f32(dst, spad, num_elems, false);
|
||||
|
||||
float sum = hvx_reduce_sum_f32(dst, num_elems);
|
||||
|
||||
return sum;
|
||||
return hvx_reduce_sum_f32(dst, num_elems);
|
||||
}
|
||||
|
||||
static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
|
||||
struct htp_ops_context * octx = smctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
htp_softmax_preamble3;
|
||||
|
||||
@@ -223,22 +234,26 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
uint64_t qt = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
|
||||
if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
// Only use the fast path when aligned AND row size is multiple of VLEN (128 bytes)
|
||||
// The fast path (hvx_fast_softmax_f32) doesn't handle tail elements
|
||||
// The non-opt path uses hvx_softmax_f32 which properly handles all sizes via its helper functions
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
@@ -278,47 +293,29 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
// ALiBi
|
||||
if (i2 != prev_i2) {
|
||||
const uint32_t h = i2; // head
|
||||
|
||||
slope = (smctx->max_bias > 0.0f) ?
|
||||
h < smctx->n_head_log2 ?
|
||||
powf(smctx->m0, h + 1) :
|
||||
powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
slope = (smctx->max_bias > 0.0f) ? h < smctx->n_head_log2 ? powf(smctx->m0, h + 1) : powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : 1.0f;
|
||||
prev_i2 = i2;
|
||||
}
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
float * sp = (float *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
float * dp = (float *) ((char *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (smctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (smctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
__fp16 * mp_f16 = (smctx->use_src1) ? (__fp16 *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL;
|
||||
float * mp_f32 = (smctx->use_src1) ? (float *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, (const uint8_t *) mp_f32, slope);
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else if (1 == opt_path) {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
|
||||
if (mp_f32) {
|
||||
if (smctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16);
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
// Non-optimized path: uses HVX helper functions that properly handle all tensor sizes
|
||||
// including non-multiples of 32 (the HVX vector lane count for f32)
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
|
||||
apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16);
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
@@ -326,54 +323,47 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
|
||||
FARF(HIGH, "softmax-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u : opt %u f16 %u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, opt_path, smctx->use_f16, (unsigned) qt);
|
||||
}
|
||||
|
||||
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * src1 = octx->src[1];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
struct htp_softmax_context smctx;
|
||||
const char * op_type = "softmax-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_SOFTMAX:
|
||||
init_softmax_ctx(&smctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
init_softmax_ctx(&smctx, octx);
|
||||
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
// 4 rows per thread, padded to HVX vector size
|
||||
octx->src0_spad.size_per_thread = hex_round_up(4 * src0_row_size, 128);
|
||||
octx->src1_spad.size_per_thread = hex_round_up(4 * src1_row_size, 128);
|
||||
octx->dst_spad.size_per_thread = hex_round_up(4 * dst_row_size, 128);
|
||||
|
||||
// Use stride for calculating offset
|
||||
smctx.spad_stride = hex_round_up(src0_row_size, 128);
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
|
||||
octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
|
||||
octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
if (src1) {
|
||||
FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
@@ -385,19 +375,17 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
|
||||
}
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return err;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
|
||||
|
||||
return err;
|
||||
}
|
||||
@@ -405,7 +393,7 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int op_softmax(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_softmax_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -16,14 +16,14 @@
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "hex-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
struct htp_tensor * restrict src0 = &octx->src0; \
|
||||
struct htp_tensor * restrict src1 = &octx->src1; \
|
||||
struct htp_tensor * restrict dst = &octx->dst; \
|
||||
#define htp_ssm_conv_tensors_preamble \
|
||||
const struct htp_tensor * restrict src0 = octx->src[0]; \
|
||||
const struct htp_tensor * restrict src1 = octx->src[1]; \
|
||||
const struct htp_tensor * restrict dst = octx->dst; \
|
||||
struct htp_spad * restrict src0_spad = &octx->src0_spad; \
|
||||
struct htp_spad * restrict src1_spad = &octx->src1_spad; \
|
||||
struct htp_spad * restrict dst_spad = &octx->dst_spad; \
|
||||
@@ -289,9 +289,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
// Compute gather scratchpad size for src0 and src1
|
||||
const size_t gather_spad_size = n_threads * VLEN * 2;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL;
|
||||
|
||||
FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
|
||||
gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
|
||||
@@ -323,8 +323,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
}
|
||||
|
||||
int op_ssm_conv(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (dst->type) {
|
||||
case HTP_TYPE_F32:
|
||||
|
||||
@@ -14,13 +14,13 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
#define sum_rows_preamble \
|
||||
const struct htp_tensor *src0 = octx->src[0]; \
|
||||
const struct htp_tensor *dst = octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
@@ -94,7 +94,7 @@ static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data)
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
sum_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
if (octx->src[0]->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_unary_context {
|
||||
@@ -267,8 +267,8 @@ static void softplus_f32(const float * restrict src,
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
htp_unary_preamble;
|
||||
|
||||
@@ -387,8 +387,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * src0 = octx->src[0];
|
||||
const struct htp_tensor * dst = octx->dst;
|
||||
|
||||
const char * op_type = NULL;
|
||||
|
||||
@@ -490,7 +490,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int op_unary(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
switch (octx->src[0]->type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_unary_f32(octx);
|
||||
break;
|
||||
|
||||
@@ -90,6 +90,8 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_q4_1_f32_flat
|
||||
mul_mv_q4_k_f32
|
||||
mul_mv_q4_k_f32_flat
|
||||
mul_mv_q5_k_f32
|
||||
mul_mv_q5_k_f32_flat
|
||||
mul_mv_q6_k_f32
|
||||
mul_mv_q6_k_f32_flat
|
||||
mul_mv_q8_0_f32
|
||||
@@ -109,6 +111,7 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_q4_1_f32_l4_lm
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
mul_mm_q4_k_f32_l4_lm
|
||||
mul_mm_q5_k_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_q4_1_f32
|
||||
|
||||
@@ -541,12 +541,15 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_convert_block_q4_K_noshuffle;
|
||||
cl_kernel kernel_restore_block_q4_K_noshuffle;
|
||||
cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K;
|
||||
cl_kernel kernel_convert_block_q5_K, kernel_restore_block_q5_K;
|
||||
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32_flat;
|
||||
cl_kernel kernel_mul_mv_q4_K_f32;
|
||||
cl_kernel kernel_mul_mv_q4_K_f32_flat;
|
||||
cl_kernel kernel_mul_mv_q5_K_f32;
|
||||
cl_kernel kernel_mul_mv_q5_K_f32_flat;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32_flat;
|
||||
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
|
||||
@@ -587,6 +590,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q4_k_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q5_k_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
|
||||
|
||||
std::vector<ProfilingInfo> profiling_info;
|
||||
@@ -938,6 +942,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err));
|
||||
@@ -1249,6 +1255,39 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q5_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q5_k_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q5_k_f32.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q5_k_f32_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q5_k_f32_flat.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q5_k_f32_flat.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32_flat", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
}
|
||||
|
||||
// mul_mv_q6_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -1556,6 +1595,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q5_k_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_q5_k_f32_l4_lm.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_q5_k_f32_l4_lm.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_k_f32_l4_lm", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_f16_f32_kq_kqv
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -3530,6 +3586,58 @@ struct ggml_tensor_extra_cl_q4_K {
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_q5_K {
|
||||
// Lower 4 bits of quantized weights.
|
||||
cl_mem q = nullptr;
|
||||
// Upper 1 bit of quantized weights.
|
||||
cl_mem qh = nullptr;
|
||||
// Scales for each block.
|
||||
cl_mem s = nullptr;
|
||||
// Scales for each super block.
|
||||
cl_mem d = nullptr;
|
||||
// Min for each super block.
|
||||
cl_mem dm = nullptr;
|
||||
|
||||
size_t size_q = 0;
|
||||
size_t size_qh = 0;
|
||||
size_t size_s = 0;
|
||||
size_t size_d = 0;
|
||||
size_t size_dm = 0;
|
||||
|
||||
~ggml_tensor_extra_cl_q5_K() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
if (q != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(q));
|
||||
q = nullptr;
|
||||
}
|
||||
if (qh != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(qh));
|
||||
qh = nullptr;
|
||||
}
|
||||
if (s != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(s));
|
||||
s = nullptr;
|
||||
}
|
||||
if (d != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(d));
|
||||
d = nullptr;
|
||||
}
|
||||
if (dm != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(dm));
|
||||
dm = nullptr;
|
||||
}
|
||||
|
||||
size_q = 0;
|
||||
size_qh = 0;
|
||||
size_s = 0;
|
||||
size_d = 0;
|
||||
size_dm = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_q6_K {
|
||||
// Lower 4 bits of quantized weights.
|
||||
cl_mem ql = nullptr;
|
||||
@@ -3945,6 +4053,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_MXFP4 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q6_K) {
|
||||
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q8_0) {
|
||||
@@ -4153,6 +4262,12 @@ struct ggml_backend_opencl_buffer_context {
|
||||
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K) {
|
||||
delete e;
|
||||
}
|
||||
for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) {
|
||||
delete e;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
|
||||
@@ -4245,6 +4360,21 @@ struct ggml_backend_opencl_buffer_context {
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_q5_K * ggml_opencl_alloc_temp_tensor_extra_q5_K() {
|
||||
ggml_tensor_extra_cl_q5_K * extra;
|
||||
if (temp_tensor_extras_q5_K.empty()) {
|
||||
extra = new ggml_tensor_extra_cl_q5_K();
|
||||
} else {
|
||||
extra = temp_tensor_extras_q5_K.back();
|
||||
temp_tensor_extras_q5_K.pop_back();
|
||||
}
|
||||
|
||||
temp_tensor_extras_q5_K_in_use.push_back(extra);
|
||||
|
||||
extra->reset();
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() {
|
||||
ggml_tensor_extra_cl_q6_K * extra;
|
||||
if (temp_tensor_extras_q6_K.empty()) {
|
||||
@@ -4291,6 +4421,11 @@ struct ggml_backend_opencl_buffer_context {
|
||||
}
|
||||
temp_tensor_extras_q4_K_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) {
|
||||
temp_tensor_extras_q5_K.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_q5_K_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
|
||||
temp_tensor_extras_q6_K.push_back(e);
|
||||
}
|
||||
@@ -4314,6 +4449,8 @@ struct ggml_backend_opencl_buffer_context {
|
||||
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K;
|
||||
std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q5_K *> temp_tensor_extras_q5_K;
|
||||
std::vector<ggml_tensor_extra_cl_q5_K *> temp_tensor_extras_q5_K_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K;
|
||||
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use;
|
||||
|
||||
@@ -5152,6 +5289,97 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q5_K) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
|
||||
// Allocate the new extra and create aliases from the original.
|
||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||
ggml_tensor_extra_cl_q5_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_K();
|
||||
|
||||
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
|
||||
size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/8;
|
||||
size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3*ggml_blck_size(tensor->type)/64);
|
||||
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
GGML_ASSERT(size_q + size_qh + size_s + size_d + size_dm == ggml_nbytes(tensor) &&
|
||||
"Incorrect tensor size");
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device;
|
||||
CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err));
|
||||
CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL));
|
||||
|
||||
cl_buffer_region region;
|
||||
|
||||
// Create subbuffer for d.
|
||||
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
|
||||
region.size = size_d;
|
||||
extra->d = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
auto previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for dm.
|
||||
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
|
||||
region.size = size_dm;
|
||||
extra->dm = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for s.
|
||||
region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment);
|
||||
region.size = size_s;
|
||||
extra->s = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for q (lower 4 bits)
|
||||
region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);
|
||||
region.size = size_q;
|
||||
extra->q = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for qh (upper 1 bit)
|
||||
region.origin = align_to(previous_origin + size_q, backend_ctx->alignment);
|
||||
region.size = size_qh;
|
||||
CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err));
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->dm));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
|
||||
extra->size_q = size_q;
|
||||
extra->size_qh = size_qh;
|
||||
extra->size_s = size_s;
|
||||
extra->size_d = size_d;
|
||||
extra->size_dm = size_dm;
|
||||
|
||||
tensor->extra = extra;
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q6_K) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
@@ -5658,6 +5886,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q5_K) {
|
||||
ggml_tensor_extra_cl_q5_K * extra = (ggml_tensor_extra_cl_q5_K *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q6_K) {
|
||||
ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;
|
||||
|
||||
@@ -10221,6 +10478,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra;
|
||||
ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra;
|
||||
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
|
||||
#endif
|
||||
|
||||
@@ -10925,6 +11183,51 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q5_K: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
int batch_stride_a = ne00*ne01;
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_a
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); // stride_b
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne01)); // stride_d
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_a));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_b));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &batch_stride_d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3));
|
||||
|
||||
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
|
||||
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q6_K: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
@@ -11442,7 +11745,81 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q5_K: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_q5_K_f32_flat;
|
||||
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 2;
|
||||
ndst = 16;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3));
|
||||
#else
|
||||
kernel = backend_ctx->kernel_mul_mv_q5_K_f32;
|
||||
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q6_K:
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat;
|
||||
@@ -11610,7 +11987,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
} else if (src0t == GGML_TYPE_Q3_K) {
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
} else if (src0t == GGML_TYPE_Q5_K) {
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
} else if (src0t == GGML_TYPE_Q6_K) {
|
||||
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
|
||||
|
||||
@@ -66,6 +66,17 @@ struct block_q4_K {
|
||||
uchar q[QK_K / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q5_k
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q5_K {
|
||||
half d; // delta
|
||||
half dm; // min
|
||||
uchar s[K_SCALE_SIZE];
|
||||
uchar qh[QK_K / 8];
|
||||
uchar qs[QK_K / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q6_K
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -546,6 +557,71 @@ kernel void kernel_restore_block_q4_K_noshuffle(
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q5_K
|
||||
// Convert the block_q5_K format to 5 separate arrays (AOS -> SOA).
|
||||
// Each thread processes a super block.
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_q5_K(
|
||||
global struct block_q5_K * src0,
|
||||
global uchar * dst_q,
|
||||
global uchar * dst_qh,
|
||||
global uchar * dst_s,
|
||||
global half * dst_d,
|
||||
global half * dst_dm
|
||||
) {
|
||||
global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0);
|
||||
global uchar * qh = (global uchar *) dst_qh + QK_K/8*get_global_id(0);
|
||||
global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0);
|
||||
global half * d = (global half *) dst_d + get_global_id(0);
|
||||
global half * dm = (global half *) dst_dm + get_global_id(0);
|
||||
|
||||
*d = b->d;
|
||||
*dm = b->dm;
|
||||
|
||||
for (int i = 0; i < QK_K/2; ++i) {
|
||||
q[i] = b->qs[i];
|
||||
}
|
||||
for (int i = 0; i < QK_K/8; ++i) {
|
||||
qh[i] = b->qh[i];
|
||||
}
|
||||
for (int i = 0; i < K_SCALE_SIZE; ++i) {
|
||||
s[i] = b->s[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Restore block_q5_K from flattened arrays.
|
||||
// Each thread processes a super block.
|
||||
kernel void kernel_restore_block_q5_K(
|
||||
global uchar * src_q,
|
||||
global uchar * src_qh,
|
||||
global uchar * src_s,
|
||||
global half * src_d,
|
||||
global half * src_dm,
|
||||
global struct block_q5_K * dst
|
||||
) {
|
||||
global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0);
|
||||
global uchar * qh = (global uchar *) src_qh + QK_K/8*get_global_id(0);
|
||||
global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0);
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
global half * dm = (global half *) src_dm + get_global_id(0);
|
||||
|
||||
b->d = *d;
|
||||
b->dm = *dm;
|
||||
|
||||
for (int i = 0; i < QK_K/2; ++i) {
|
||||
b->qs[i] = q[i];
|
||||
}
|
||||
for (int i = 0; i < QK_K/8; ++i) {
|
||||
b->qh[i] = qh[i];
|
||||
}
|
||||
for (int i = 0; i < K_SCALE_SIZE; ++i) {
|
||||
b->s[i] = s[i];
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q6_K
|
||||
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 4
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q5_k_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global uchar * src0_qh,
|
||||
global uchar * src0_s,
|
||||
global half * src0_d,
|
||||
global half * src0_dm,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 64;
|
||||
int iqs = (idx % 64) * 2;
|
||||
|
||||
int n = iqs / 32;
|
||||
int b = (iqs % 32) / 16;
|
||||
int is = 2 * n + b;
|
||||
int qsi = n * 32 + (iqs % 16) * 2;
|
||||
|
||||
global uchar * scales = src0_s + ib * 12;
|
||||
|
||||
int scidx0 = (is < 4) ? is : (is + 4);
|
||||
int scidx1 = (is < 4) ? is : (is - 4);
|
||||
int scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
int scidxshift1 = (is < 4) ? 0 : 2;
|
||||
int mbidx0 = is + 4;
|
||||
int mbidx1 = (is < 4) ? is + 4 : is;
|
||||
int mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||
int mbidxshift0 = (is < 4) ? 0 : 4;
|
||||
int mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
int mbidxshift1 = (is < 4) ? 0 : 2;
|
||||
|
||||
uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1);
|
||||
uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1);
|
||||
|
||||
float d = (float)src0_d[ib] * (float)sc;
|
||||
float m = -(float)src0_dm[ib] * (float)mbyte;
|
||||
|
||||
int qh_base = (iqs % 16) * 2;
|
||||
int bit_pos = 2*n + b;
|
||||
uchar h0 = (src0_qh[ib*32 + qh_base + 0] >> bit_pos) & 1;
|
||||
uchar h1 = (src0_qh[ib*32 + qh_base + 1] >> bit_pos) & 1;
|
||||
uchar h2 = (src0_qh[ib*32 + qh_base + 2] >> bit_pos) & 1;
|
||||
uchar h3 = (src0_qh[ib*32 + qh_base + 3] >> bit_pos) & 1;
|
||||
|
||||
global uchar4 * qs = src0_q + ib*32 + (qsi >> 2);
|
||||
uchar4 q = *qs;
|
||||
float4 v1 = (convert_float4((uchar4)(
|
||||
((q.s0 >> (b * 4))&0x0F) | (h0 << 4),
|
||||
((q.s1 >> (b * 4))&0x0F) | (h1 << 4),
|
||||
((q.s2 >> (b * 4))&0x0F) | (h2 << 4),
|
||||
((q.s3 >> (b * 4))&0x0F) | (h3 << 4)
|
||||
)))*d + m;
|
||||
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte)
|
||||
uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte)
|
||||
} block_q5_K;
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 16
|
||||
#elif defined(ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q5_K_f32(
|
||||
global char * src0,
|
||||
int offset0,
|
||||
global char * src1,
|
||||
int offset1,
|
||||
global char * dst,
|
||||
int offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
ushort kmask1 = 0x3f3f;
|
||||
ushort kmask2 = 0x0f0f;
|
||||
ushort kmask3 = 0xc0c0;
|
||||
|
||||
int ix = get_sub_group_local_id()/8; // super block index
|
||||
int it = get_sub_group_local_id()%8; // block index (inside super block)
|
||||
int iq = it/4; // 0 or 1 - first or second half of the super block
|
||||
int ir = it%4; // 0...3 - block index in the half super block
|
||||
|
||||
int nb = ne00/QK_K;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global block_q5_K * x = (global block_q5_K *) (src0 + offset_src0);
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float sumf[N_DST] = {0.f};
|
||||
float all_sum;
|
||||
|
||||
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
uchar u1_lo = (uchar)(1 << (2*iq));
|
||||
uchar u2_lo = (uchar)(2 << (2*iq));
|
||||
uchar u1_hi = (uchar)(1 << (2*iq + 4));
|
||||
uchar u2_hi = (uchar)(2 << (2*iq + 4));
|
||||
|
||||
ushort sc16[4];
|
||||
uchar * sc8 = (uchar *)sc16;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = y4[i+0];
|
||||
sumy.s0 += yl[i+0];
|
||||
|
||||
yl[i+8] = y4[i+32];
|
||||
sumy.s1 += yl[i+8];
|
||||
|
||||
yh[i+0] = y4[i+128];
|
||||
sumy.s2 += yh[i+0];
|
||||
|
||||
yh[i+8] = y4[i+160];
|
||||
sumy.s3 += yh[i+8];
|
||||
}
|
||||
|
||||
global ushort * sc = (global ushort *)x[ib].scales + iq;
|
||||
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
|
||||
global uchar * qh = x[ib].qh + 8 * ir;
|
||||
global half * dh = &x[ib].d;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
sc16[0] = sc[0] & kmask1;
|
||||
sc16[1] = sc[2] & kmask1;
|
||||
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
||||
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
||||
|
||||
global ushort * q2 = q1 + 32;
|
||||
|
||||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f));
|
||||
acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f));
|
||||
acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f));
|
||||
acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f));
|
||||
acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f));
|
||||
acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f));
|
||||
acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f));
|
||||
acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f));
|
||||
}
|
||||
|
||||
float dall = dh[0];
|
||||
float dmin = dh[1];
|
||||
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
|
||||
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
|
||||
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
|
||||
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
|
||||
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
|
||||
|
||||
q1 += nb01/2;
|
||||
sc += nb01/2;
|
||||
dh += nb01/2;
|
||||
qh += nb01;
|
||||
}
|
||||
|
||||
y4 += BLOCK_STRIDE * QK_K;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = sub_group_reduce_add(sumf[row]);
|
||||
if (first_row + row < ne01) {
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q5_K
|
||||
//------------------------------------------------------------------------------
|
||||
#define QK_K 256
|
||||
#define BLOCK_Q5K_SIZE 176
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte)
|
||||
uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte)
|
||||
} block_q5_K;
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 16
|
||||
#elif defined(ADRENO_GPU)
|
||||
#define N_DST 16
|
||||
#define N_SIMDGROUP 2
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
#undef BLOCK_STRIDE
|
||||
// number of (super) blocks each subgroup processes
|
||||
// each thread in a subgroup processes a block (32 weights)
|
||||
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q5_K_f32_flat(
|
||||
global uchar * src0_q,
|
||||
global uchar * src0_qh,
|
||||
global uchar * src0_s,
|
||||
global half * src0_d,
|
||||
global half * src0_dm,
|
||||
global char * src1,
|
||||
int offset1,
|
||||
global char * dst,
|
||||
int offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
ushort kmask1 = 0x3f3f;
|
||||
ushort kmask2 = 0x0f0f;
|
||||
ushort kmask3 = 0xc0c0;
|
||||
|
||||
int ix = get_sub_group_local_id()/8;
|
||||
int it = get_sub_group_local_id()%8;
|
||||
int iq = it/4;
|
||||
int ir = it%4;
|
||||
|
||||
int nb = ne00/QK_K;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q5K_SIZE;
|
||||
uint blk = nb01 / BLOCK_Q5K_SIZE;
|
||||
global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2);
|
||||
global uchar * blk_qh = (global uchar *)src0_qh + offset_src0*(QK_K/8);
|
||||
global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE;
|
||||
global half * blk_d = (global half *)src0_d + offset_src0;
|
||||
global half * blk_dm = (global half *)src0_dm + offset_src0;
|
||||
|
||||
int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13;
|
||||
global float * y = (global float *)(src1 + offset_src1);
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float sumf[N_DST] = {0.f};
|
||||
float all_sum;
|
||||
|
||||
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
uchar u1_lo = (uchar)(1 << (2*iq));
|
||||
uchar u2_lo = (uchar)(2 << (2*iq));
|
||||
uchar u1_hi = (uchar)(1 << (2*iq + 4));
|
||||
uchar u2_hi = (uchar)(2 << (2*iq + 4));
|
||||
|
||||
ushort sc16[4];
|
||||
uchar * sc8 = (uchar *)sc16;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = y4[i+0];
|
||||
sumy.s0 += yl[i+0];
|
||||
|
||||
yl[i+8] = y4[i+32];
|
||||
sumy.s1 += yl[i+8];
|
||||
|
||||
yh[i+0] = y4[i+128];
|
||||
sumy.s2 += yh[i+0];
|
||||
|
||||
yh[i+8] = y4[i+160];
|
||||
sumy.s3 += yh[i+8];
|
||||
}
|
||||
|
||||
global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir);
|
||||
global uchar * qh = (global uchar *)(blk_qh + ib * (QK_K/8)) + 8 * ir;
|
||||
global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq;
|
||||
global half * d = blk_d + ib;
|
||||
global half * dm = blk_dm + ib;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
sc16[0] = sc[0] & kmask1;
|
||||
sc16[1] = sc[2] & kmask1;
|
||||
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
||||
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
||||
|
||||
global ushort * q2 = q1 + 32;
|
||||
|
||||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f));
|
||||
acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f));
|
||||
acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f));
|
||||
acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f));
|
||||
acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f));
|
||||
acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f));
|
||||
acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f));
|
||||
acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f));
|
||||
}
|
||||
|
||||
float dall = *d;
|
||||
float dmin = *dm;
|
||||
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
|
||||
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
|
||||
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
|
||||
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
|
||||
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
|
||||
|
||||
q1 += blk*64;
|
||||
qh += blk*32;
|
||||
sc += blk*6;
|
||||
d += blk;
|
||||
dm += blk;
|
||||
}
|
||||
|
||||
y4 += BLOCK_STRIDE * QK_K;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = sub_group_reduce_add(sumf[row]);
|
||||
if (first_row + row < ne01) {
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -488,7 +488,7 @@ static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t
|
||||
const int nb = k / QK_NVFP4;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
dequantize_block_nvfp4(vx, y, k);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#define GGML_SYCL_DEQUANTIZE_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
#include "convert.hpp"
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
||||
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
|
||||
|
||||
@@ -355,7 +355,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -176,14 +176,12 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
|
||||
const sycl::uint3 rq3_magic = init_fastdiv_values(rq3);
|
||||
|
||||
int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
|
||||
|
||||
switch (S_v) {
|
||||
case 16:
|
||||
{
|
||||
constexpr int sv = 16;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
@@ -194,7 +192,7 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
{
|
||||
constexpr int sv = 32;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
|
||||
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
|
||||
sb3, neqk1_magic, rq3_magic, scale);
|
||||
@@ -205,7 +203,7 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
{
|
||||
constexpr int sv = 64;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
@@ -217,7 +215,7 @@ static void launch_gated_delta_net(const float * q_d,
|
||||
{
|
||||
constexpr int sv = 128;
|
||||
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_delta_net_sycl<sv, KDA>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
|
||||
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
|
||||
@@ -4727,12 +4727,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
struct ggml_tensor * a = op->src[0];
|
||||
struct ggml_tensor * b = op->src[1];
|
||||
|
||||
// disable Q1_0 until implementation
|
||||
if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a->ne[3] != b->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
||||
|
||||
|
||||
// TODO: The configuration below needs more work to be supported with oneDNN
|
||||
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
|
||||
a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
|
||||
|
||||
@@ -272,7 +272,7 @@ static void upscale_f32_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
|
||||
});
|
||||
}
|
||||
@@ -304,7 +304,7 @@ static void upscale_f32_bilinear_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bilinear_antialias(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
@@ -314,7 +314,7 @@ static void upscale_f32_bilinear_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bilinear(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst,
|
||||
ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
@@ -349,7 +349,7 @@ static void upscale_f32_bicubic_sycl(const float * x,
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
upscale_f32_bicubic(
|
||||
x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst,
|
||||
ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
|
||||
@@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
|
||||
}
|
||||
};
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
GGML_UNUSED(kv_type);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_SCALAR;
|
||||
@@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
||||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
|
||||
result.block_rows /= 2;
|
||||
}
|
||||
|
||||
@@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
|
||||
}
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product && device->subgroup_clustered) {
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
|
||||
}
|
||||
}
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
@@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
GGML_UNUSED(f32acc);
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = params.workgroup_size;
|
||||
@@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
|
||||
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
|
||||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
|
||||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
|
||||
|
||||
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
||||
|
||||
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
uint32_t Qf, kvsh, kblocksh_size;
|
||||
if (mmq) {
|
||||
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
|
||||
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
|
||||
Qf = Br * (hsk / 32) * block_b_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
// kvsh uses D = HSV (K goes through kblocksh instead)
|
||||
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
|
||||
// block_a_cache size depends on quant type
|
||||
uint32_t block_a_size;
|
||||
switch (kv_type) {
|
||||
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
|
||||
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
|
||||
default: block_a_size = 0; break;
|
||||
}
|
||||
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
|
||||
} else {
|
||||
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
kblocksh_size = 0;
|
||||
}
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,13 @@
|
||||
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef MMQ
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
@@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
|
||||
const uint32_t masksh_stride = Br + 1;
|
||||
shared FLOAT_TYPE masksh[Bc * masksh_stride];
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t qf_stride = HSK / 4 + 1;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
|
||||
#else
|
||||
|
||||
const uint32_t qf_stride = HSK / 32;
|
||||
shared block_b_cache Qf[Br * qf_stride];
|
||||
#endif
|
||||
|
||||
#ifndef MMQ
|
||||
const uint32_t D = HSK > HSV ? HSK : HSV;
|
||||
#else
|
||||
const uint32_t D = HSV;
|
||||
#endif
|
||||
const uint32_t kvsh_stride = D / 4 + 1;
|
||||
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
|
||||
shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1];
|
||||
#endif
|
||||
|
||||
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
|
||||
|
||||
#ifdef MMQ
|
||||
#include "flash_attn_mmq_funcs.glsl"
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
@@ -82,10 +108,39 @@ void main() {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N;
|
||||
#ifndef MMQ
|
||||
if (is_in_bounds) {
|
||||
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
#else
|
||||
const uint buf_ib = r * qf_stride + d / 8;
|
||||
const uint buf_iqs = d % 8;
|
||||
|
||||
FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f);
|
||||
const FLOAT_TYPEV4 abs_vals = abs(vals);
|
||||
|
||||
const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
||||
const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8);
|
||||
const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0);
|
||||
const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0);
|
||||
vals = round(vals * qd_inv);
|
||||
|
||||
Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals));
|
||||
|
||||
#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0);
|
||||
}
|
||||
#else // Q4_0, Q4_1, Q5_0, Q5_1
|
||||
const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
|
||||
if (buf_iqs == 0) {
|
||||
Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
barrier();
|
||||
|
||||
@@ -195,6 +250,7 @@ void main() {
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
#ifndef MMQ
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
@@ -214,9 +270,29 @@ void main() {
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint ints_per_block = 8 / QUANT_R_MMQ;
|
||||
const uint quant_iters = Bc * HSK / 32 * ints_per_block;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) {
|
||||
const uint32_t iqs = (idx + tid) % ints_per_block;
|
||||
const uint32_t ib = (idx + tid) / ints_per_block;
|
||||
const uint32_t c = ib / (HSK / 32);
|
||||
const uint32_t block = ib % (HSK / 32);
|
||||
if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) {
|
||||
const uint buf_ib = c * qf_stride + block;
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
const uint global_ib = (j * Bc + c) * k_stride + block;
|
||||
k_block_to_shmem(buf_ib, global_ib, iqs, k_offset);
|
||||
} else {
|
||||
k_block_to_shmem_zero(buf_ib, iqs);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
barrier();
|
||||
}
|
||||
|
||||
#ifndef MMQ
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
// Few iterations means the additional registers needed are worse than the speed-up from caching
|
||||
if (HSK_per_thread / 4 > 4) {
|
||||
@@ -275,6 +351,110 @@ void main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // MMQ
|
||||
const uint hsk4 = HSK_per_thread / 4;
|
||||
const uint d_per_step = (hsk4 % 8 == 0) ? 8 :
|
||||
(hsk4 % 4 == 0) ? 4 :
|
||||
(hsk4 % 2 == 0) ? 2 : 1;
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) {
|
||||
int32_t k_quants[d_per_step];
|
||||
ACC_TYPEV2 k_dm;
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx;
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm);
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
uint vui = kblocksh[buf_ib].qs[d];
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block);
|
||||
const uint ib = coord / BLOCK_SIZE;
|
||||
const uint iqs = (coord % BLOCK_SIZE);
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0);
|
||||
#else
|
||||
k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset));
|
||||
#endif
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
if (d_per_step == 8) {
|
||||
#if defined(DATA_A_Q5_0)
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qh[1]));
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
uint qh = k_packed.k_data_packed16[k_offset + ib].qh;
|
||||
#endif
|
||||
[[unroll]] for (uint32_t d = 0; d < 4; d++) {
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d];
|
||||
#else
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0],
|
||||
k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1]));
|
||||
#endif
|
||||
k_quants[d ] = int32_t( vui & 0x0F0F0F0F);
|
||||
k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F);
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint qh_lo = (qh >> (d * 4)) & 0xF;
|
||||
uint qh_hi = (qh >> (d * 4 + 16)) & 0xF;
|
||||
k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u);
|
||||
k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u);
|
||||
#endif
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8;
|
||||
const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8;
|
||||
|
||||
int32_t acc = 0;
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_step; d++) {
|
||||
acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]);
|
||||
}
|
||||
|
||||
Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x;
|
||||
if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) {
|
||||
Sf[r][c] += k_dot_correction(qib, k_dm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MMQ
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
// Compute sum across the D_split
|
||||
|
||||
@@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32;
|
||||
layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32;
|
||||
#endif
|
||||
|
||||
#ifndef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 1
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q4_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
return int32_t(vui & 0x0F0F0F0F);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
#ifdef DATA_A_Q5_0
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qh[1]));
|
||||
#else
|
||||
uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4];
|
||||
uint qh = k_packed.k_data_packed16[a_offset + ib].qh;
|
||||
#endif
|
||||
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
uint qh_bits = (qh >> iqs) & 0xF;
|
||||
return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
int32_t get_k_qs(uint ib, uint iqs, uint a_offset) {
|
||||
uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0],
|
||||
k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]));
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui >>= shift;
|
||||
|
||||
u8vec4 idx = unpack8(vui & 0x0F0F0F0F);
|
||||
return pack32(i8vec4(kvalues_iq4nl_const[idx.x],
|
||||
kvalues_iq4nl_const[idx.y],
|
||||
kvalues_iq4nl_const[idx.z],
|
||||
kvalues_iq4nl_const[idx.w]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE get_k_d(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d);
|
||||
}
|
||||
#else
|
||||
FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) {
|
||||
return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qh[1]));
|
||||
}
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs];
|
||||
if (iqs == 0) {
|
||||
kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh;
|
||||
}
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2],
|
||||
k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1]));
|
||||
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
|
||||
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
|
||||
kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y],
|
||||
kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w]));
|
||||
kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y],
|
||||
kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w]));
|
||||
#endif
|
||||
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) {
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
|
||||
uint sub = pos % 4;
|
||||
uint shift = ((pos % 8) >= 4) ? 4 : 0;
|
||||
int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F);
|
||||
uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF;
|
||||
return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u);
|
||||
#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
return kblocksh[buf_ib].qs[pos];
|
||||
#endif
|
||||
}
|
||||
|
||||
ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x;
|
||||
#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
return ACC_TYPE(Qf[qib].ds.y) * k_dm.y;
|
||||
#else
|
||||
return ACC_TYPE(0.0);
|
||||
#endif
|
||||
}
|
||||
|
||||
void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) {
|
||||
kblocksh[buf_ib].qs[iqs] = 0;
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
kblocksh[buf_ib].qs[iqs + 4] = 0;
|
||||
#endif
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f);
|
||||
#else
|
||||
kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,12 @@ struct block_a_cache {
|
||||
int32_t qs[32/4];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE dm;
|
||||
};
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
|
||||
@@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
#define QUANT_K QUANT_K_IQ4_NL
|
||||
#define QUANT_R QUANT_R_IQ4_NL
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_iq4_nl
|
||||
#define A_TYPE_PACKED16 block_iq4_nl_packed16
|
||||
#endif
|
||||
|
||||
@@ -406,8 +406,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
|
||||
}
|
||||
|
||||
static std::vector<std::future<void>> compiles;
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") {
|
||||
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix;
|
||||
std::string out_path = join_paths(output_dir, name + ".spv");
|
||||
|
||||
if (input_filepath == "") {
|
||||
@@ -625,15 +625,16 @@ void process_shaders() {
|
||||
for (const bool& fp16 : {false, true}) {
|
||||
std::map<std::string, std::string> base_dict;
|
||||
if (fp16) {
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
} else {
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
}
|
||||
|
||||
// flash attention
|
||||
for (const bool& f16acc : {false, true}) {
|
||||
std::map<std::string, std::string> fa_base_dict = base_dict;
|
||||
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2";
|
||||
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
|
||||
if (fp16 && f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
@@ -672,6 +673,12 @@ void process_shaders() {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (tname != "f32") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib {
|
||||
std::string type_upper = type_str;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (key.src_type)
|
||||
{
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
@@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib {
|
||||
variant += "_";
|
||||
variant += type_str;
|
||||
|
||||
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
|
||||
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
@@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib {
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// quantized types
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
switch (context.src0->type)
|
||||
{
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
// Quantized types using u32 buffers for portability.
|
||||
defines.push_back("SRC0_TYPE=u32");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back(type_upper + "_T");
|
||||
defines.push_back(type_upper);
|
||||
|
||||
@@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
||||
|
||||
/* End Constants */
|
||||
|
||||
static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
|
||||
#ifdef __EMSCRIPTEN__
|
||||
return wgpu::CallbackMode::AllowProcessEvents;
|
||||
#else
|
||||
return wgpu::CallbackMode::AllowSpontaneous;
|
||||
#endif
|
||||
}
|
||||
|
||||
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
||||
// their locations.
|
||||
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
||||
@@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
|
||||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
ggml_webgpu_callback_mode(),
|
||||
[&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
@@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
|
||||
std::string callback_message;
|
||||
|
||||
const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
|
||||
buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
|
||||
buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
|
||||
[&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
callback_status = status;
|
||||
callback_message = std::string(message);
|
||||
@@ -542,7 +550,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context &
|
||||
auto ts_bufs = command.timestamp_query_bufs;
|
||||
|
||||
wgpu::Future f = ts_bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
|
||||
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
|
||||
@@ -3420,7 +3428,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
&options, ggml_webgpu_callback_mode(),
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
@@ -3449,13 +3457,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Only support square f16 matrices of size 8 or 16 for now
|
||||
// Accept f16 subgroup matrix configurations (square or non-square).
|
||||
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
|
||||
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
|
||||
// The shaders are already parameterized to handle any M/N/K dimensions.
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
||||
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
||||
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
||||
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
|
||||
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
|
||||
@@ -3491,8 +3501,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
dev_desc.requiredFeatureCount = required_features.size();
|
||||
dev_desc.SetDeviceLostCallback(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
ggml_webgpu_callback_mode(),
|
||||
[ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
||||
if (reason == wgpu::DeviceLostReason::Destroyed) {
|
||||
return;
|
||||
}
|
||||
@@ -3525,7 +3535,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->adapter.RequestDevice(
|
||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||
&dev_desc, ggml_webgpu_callback_mode(),
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
if (status != wgpu::RequestDeviceStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
||||
@@ -3793,6 +3803,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must be divisible by subgroup matrix dimensions
|
||||
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
|
||||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
|
||||
break;
|
||||
}
|
||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
@@ -4046,6 +4061,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 0;
|
||||
|
||||
// Keep one Dawn/WebGPU instance alive for the lifetime of the static backend
|
||||
// registry. Recreating it on repeated registry lookups can invalidate
|
||||
// adapter/device references that are still held by the backend/device layer.
|
||||
if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
return ®
|
||||
}
|
||||
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
instance_descriptor.requiredFeatures = instance_features.data();
|
||||
@@ -4063,11 +4085,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
|
||||
ctx.webgpu_global_ctx->instance = std::move(inst);
|
||||
|
||||
// Probe for adapter support
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx.webgpu_global_ctx->instance.WaitAny(
|
||||
ctx.webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
|
||||
@@ -9,35 +9,43 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
|
||||
#endif
|
||||
|
||||
#ifdef U32_DEQUANT_HELPERS
|
||||
fn load_src0_u16_at(byte_offset: u32) -> u32 {
|
||||
let word = src0[byte_offset / 4u];
|
||||
let shift = (byte_offset & 2u) * 8u;
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
fn load_u16_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> u32 {
|
||||
let word = buf[byte_offset / 4];
|
||||
let shift = (byte_offset & 0x2) * 8;
|
||||
return (word >> shift) & 0xFFFF;
|
||||
}
|
||||
|
||||
fn load_src0_u32_at(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 3u) * 8u;
|
||||
let lo = src0[word_idx];
|
||||
if (shift == 0u) {
|
||||
return lo;
|
||||
}
|
||||
let hi = src0[word_idx + 1u];
|
||||
return (lo >> shift) | (hi << (32u - shift));
|
||||
fn load_u32_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4;
|
||||
let shift = (byte_offset & 0x3) * 8;
|
||||
let lo = buf[word_idx];
|
||||
let hi = buf[word_idx + 1];
|
||||
let shifted = (lo >> shift) | (hi << (32 - shift));
|
||||
return select(shifted, lo, shift == 0);
|
||||
}
|
||||
|
||||
fn load_src0_f16_at(byte_offset: u32) -> f16 {
|
||||
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
|
||||
fn load_f16_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> f16 {
|
||||
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
|
||||
return f16(packed[0]);
|
||||
}
|
||||
|
||||
fn load_f16_as_f32_at(
|
||||
buf: ptr<storage, array<u32>, read_write>,
|
||||
byte_offset: u32) -> f32 {
|
||||
let word = buf[byte_offset / 4];
|
||||
let shift = (byte_offset & 0x2) * 8;
|
||||
let d_bits = (word >> shift) & 0xFFFF;
|
||||
return unpack2x16float(d_bits)[0];
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0_T
|
||||
struct q4_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef Q4_1_T
|
||||
struct q4_1 {
|
||||
@@ -47,13 +55,6 @@ struct q4_1 {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q5_0_T
|
||||
struct q5_0 {
|
||||
d: f16,
|
||||
qh: array<f16, 2>,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q5_1_T
|
||||
struct q5_1 {
|
||||
@@ -64,12 +65,6 @@ struct q5_1 {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q8_0_T
|
||||
struct q8_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 16>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q8_1_T
|
||||
struct q8_1 {
|
||||
@@ -88,14 +83,6 @@ struct q2_K {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q3_K_T
|
||||
struct q3_K {
|
||||
hmask: array<f16, 16>,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 6>,
|
||||
d: f16
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
|
||||
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
|
||||
@@ -132,64 +119,6 @@ struct q5_K {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef Q6_K_T
|
||||
struct q6_K {
|
||||
ql: array<f16, 64>,
|
||||
qh: array<f16, 32>,
|
||||
scales: array<f16, 8>,
|
||||
d: f16
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XXS_T
|
||||
struct iq2_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_XS_T
|
||||
struct iq2_xs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ2_S_T
|
||||
struct iq2_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_XXS_T
|
||||
struct iq3_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 48>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_S_T
|
||||
struct iq3_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
signs: array<f16, 16>,
|
||||
scales: array<f16, 2>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_S_T
|
||||
struct iq1_s {
|
||||
d: f16,
|
||||
qs: array<f16, 16>,
|
||||
qh: array<f16, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ1_M_T
|
||||
struct iq1_m {
|
||||
qs: array<u32, 8>,
|
||||
@@ -198,17 +127,9 @@ struct iq1_m {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_NL_T
|
||||
struct iq4_nl {
|
||||
d: f16,
|
||||
qs: array<f16, 8>,
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef IQ4_XS_T
|
||||
struct iq4_xs {
|
||||
d: f16,
|
||||
scales_h: f16,
|
||||
d_scales_h: u32,
|
||||
scales_l: u32,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
|
||||
@@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
#endif
|
||||
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
|
||||
let inter_offset = kv_block * SG_MAT_N;
|
||||
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||
var acc: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||
|
||||
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
|
||||
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, 0u, false, HEAD_DIM_QK);
|
||||
|
||||
#ifdef KV_DIRECT
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + 0u, true, params.stride_k1);
|
||||
#else
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
|
||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
|
||||
var t: u32 = 1u;
|
||||
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
|
||||
let h0 = t * SG_MAT_K;
|
||||
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
|
||||
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h0, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h0, true, params.stride_k1);
|
||||
#else
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
|
||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = q0;
|
||||
k_cur = k0;
|
||||
|
||||
let h1 = (t + 1u) * SG_MAT_K;
|
||||
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
|
||||
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h1, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h1, true, params.stride_k1);
|
||||
#else
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
|
||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = q1g;
|
||||
@@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
// handle odd tail
|
||||
if (t < HEAD_DIM_QK / SG_MAT_K) {
|
||||
let h = t * SG_MAT_K;
|
||||
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
|
||||
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h, false, HEAD_DIM_QK);
|
||||
#ifdef KV_DIRECT
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h, true, params.stride_k1);
|
||||
#else
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
|
||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
|
||||
#endif
|
||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||
q_cur = qn;
|
||||
@@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
head_dim_block < HEAD_DIM_V;
|
||||
head_dim_block += num_subgroups * SG_MAT_N) {
|
||||
// load O submatrix from shared memory
|
||||
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
|
||||
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(
|
||||
&o_shmem,
|
||||
head_dim_block,
|
||||
false,
|
||||
@@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
);
|
||||
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
|
||||
let p_offset = kv_block * SG_MAT_N;
|
||||
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
|
||||
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(
|
||||
&inter_shmem,
|
||||
p_offset,
|
||||
false,
|
||||
@@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
#ifdef KV_DIRECT
|
||||
let v_block_row = kv_tile + kv_block * SG_MAT_N;
|
||||
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
|
||||
&V,
|
||||
v_global_offset,
|
||||
false,
|
||||
@@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
);
|
||||
#else
|
||||
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
|
||||
&kv_shmem,
|
||||
v_block_offset + head_dim_block,
|
||||
false,
|
||||
|
||||
@@ -27,17 +27,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q4_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q4_0 = src[src_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
for (var j: u32 = 0u; j < 4; j++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
dst[dst_offset + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,17 +65,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q5_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q5_0 = src[src_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
|
||||
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
let qh_packed = load_u32_at(&src, block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
|
||||
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
dst[dst_offset] = q_lo;
|
||||
dst[dst_offset + 16] = q_hi;
|
||||
@@ -106,14 +112,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q8_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q8_0 = src[src_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
for (var j: u32 = 0u; j < 8u; j++) {
|
||||
let q_byte_offset = block_byte_base + 2u + j * 4u;
|
||||
let q_packed = load_u32_at(&src, q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let dst_offset = dst_base + offset * 32 + j * 4 + k;
|
||||
let dst_offset = dst_base + offset * 32u + j * 4u + k;
|
||||
dst[dst_offset] = q_val;
|
||||
}
|
||||
}
|
||||
@@ -152,36 +159,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef Q3_K
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
|
||||
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
|
||||
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
var is: u32 = 0;
|
||||
var m: u32 = 1;
|
||||
|
||||
// 2 halves of the block (128 elements each)
|
||||
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
|
||||
// 4 groups (each group has 2 blocks of 16 elements)
|
||||
@@ -191,11 +204,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let sc = get_byte(scale_vals[is / 4], is % 4);
|
||||
is++;
|
||||
let dl = d * (f32(sc) - 32.0);
|
||||
for (var l: u32 = 0u; l < 16u; l++) {
|
||||
|
||||
for (var l: u32 = 0; l < 16; l++) {
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
|
||||
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3;
|
||||
dst[dst_i] = (f32(qs_val) - hm) * dl;
|
||||
@@ -268,21 +283,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
|
||||
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
|
||||
for (var i: u32 = 0; i < 16u; i++) {
|
||||
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var dst_i = dst_base + offset * 256;
|
||||
@@ -323,12 +344,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ2_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
|
||||
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at(&src, aux0_offset);
|
||||
let aux1 = load_u32_at(&src, aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
@@ -345,15 +368,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
#ifdef IQ2_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
load_u32_at(&src, block_byte_base + 66),
|
||||
load_u32_at(&src, block_byte_base + 70)
|
||||
);
|
||||
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
let db = array<f32, 2>(
|
||||
@@ -361,7 +388,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
@@ -379,21 +407,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ2_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
|
||||
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
let db = array<f32, 2>(
|
||||
@@ -419,16 +449,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ3_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at(&src, sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -448,18 +479,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ3_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
load_u32_at(&src, block_byte_base + 66),
|
||||
load_u32_at(&src, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
|
||||
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
|
||||
}
|
||||
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
|
||||
|
||||
var scale_vals = load_u32_at(&src, block_byte_base + 106);
|
||||
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
let db = array<f32, 2>(
|
||||
@@ -472,7 +507,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -493,14 +528,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ1_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
|
||||
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
|
||||
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
|
||||
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
@@ -560,12 +595,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
|
||||
#ifdef IQ4_NL
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src, block_byte_base);
|
||||
var dst_i = dst_base + offset * 32;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
@@ -579,8 +614,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
#ifdef IQ4_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var dst_i = dst_base + offset * 256;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
|
||||
|
||||
@@ -20,11 +20,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef Q4_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q4_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
|
||||
@@ -61,12 +62,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef Q5_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q5_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
|
||||
let q_byte_offset = block_byte_base + 6 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
|
||||
@@ -107,12 +109,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef Q8_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q8_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var sum: f32 = 0.0;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte_offset = block_byte_base + 2 + j * 4;
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
|
||||
@@ -178,31 +181,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
#ifdef Q3_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
|
||||
// Bytes 108-109: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
|
||||
|
||||
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
|
||||
// and 2-bits from the last 4 bytes
|
||||
// Bytes 96-107: 12 bytes of scales (3 u32s)
|
||||
let kmask1: u32 = 0x03030303;
|
||||
let kmask2: u32 = 0x0f0f0f0f;
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
}
|
||||
scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
|
||||
scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
|
||||
scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 0-31: 32 bytes of hmask (8 u32s)
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
|
||||
hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 32-95: 64 bytes of qs (16 u32s)
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
|
||||
for (var i: u32 = 0u; i < 16; i++) {
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
@@ -301,21 +310,27 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
|
||||
|
||||
// convert arrays of f16 -> u32
|
||||
// Bytes 208-209: f16 scale 'd'
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
|
||||
|
||||
// Bytes 0-127: 128 bytes of ql (32 u32s)
|
||||
var ql_vals: array<u32, 32>;
|
||||
for (var i: u32 = 0; i < 32; i++) {
|
||||
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
|
||||
ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 128-191: 64 bytes of qh (16 u32s)
|
||||
var qh_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
|
||||
qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
|
||||
}
|
||||
|
||||
// Bytes 192-207: 16 bytes of scales (4 u32s)
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
|
||||
}
|
||||
|
||||
var sum = 0.0;
|
||||
@@ -358,13 +373,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ2_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
|
||||
let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
|
||||
let aux0_offset = block_byte_base + 2 + ib * 2;
|
||||
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
|
||||
let aux0 = load_u32_at(&src0, aux0_offset);
|
||||
let aux1 = load_u32_at(&src0, aux1_offset);
|
||||
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = get_byte(aux0, l) * 8;
|
||||
@@ -384,13 +401,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ2_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
load_u32_at(&src0, block_byte_base + 66),
|
||||
load_u32_at(&src0, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 32; ib += 4) {
|
||||
let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
|
||||
@@ -399,7 +418,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
d * (0.5 + f32(s >> 4)) * 0.25
|
||||
);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
|
||||
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
|
||||
let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
|
||||
let ig = (qs_val & 511) * 8;
|
||||
let is = qs_val >> 9;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
@@ -418,21 +438,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ2_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qs_vals : array<u32, 16>;
|
||||
for (var i: u32 = 0; i < 16; i++) {
|
||||
qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
);
|
||||
var scale_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.scales[0], block.scales[1])),
|
||||
bitcast<u32>(vec2(block.scales[2], block.scales[3]))
|
||||
);
|
||||
|
||||
var qh_vals: array<u32, 2>;
|
||||
qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
|
||||
qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
|
||||
|
||||
var scale_vals: array<u32, 2>;
|
||||
scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
|
||||
scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib ++) {
|
||||
let s = get_byte(scale_vals[ib / 4], ib % 4);
|
||||
@@ -460,17 +482,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ3_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 16; ib += 2) {
|
||||
let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
|
||||
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
|
||||
let sc_sign = load_u32_at(&src0, sc_sign_offset);
|
||||
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let is = (sc_sign >> (7 * l)) & 127;
|
||||
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0);
|
||||
let ig2 = get_byte(ig_val, 1);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -491,18 +514,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ3_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
|
||||
var qh_vals = array<u32, 2>(
|
||||
bitcast<u32>(vec2(block.qh[0], block.qh[1])),
|
||||
bitcast<u32>(vec2(block.qh[2], block.qh[3]))
|
||||
load_u32_at(&src0, block_byte_base + 66),
|
||||
load_u32_at(&src0, block_byte_base + 70)
|
||||
);
|
||||
|
||||
var sign_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0; i < 8; i++) {
|
||||
sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
|
||||
sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
|
||||
}
|
||||
var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
|
||||
|
||||
var scale_vals = load_u32_at(&src0, block_byte_base + 106);
|
||||
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 4; ib++) {
|
||||
let s = get_byte(scale_vals, ib);
|
||||
@@ -516,7 +543,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let sign_w = sign_vals[ib * 2 + k];
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let signs = get_byte(sign_w, l);
|
||||
let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
|
||||
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
|
||||
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
|
||||
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
|
||||
for (var j: u32 = 0; j < 4; j++) {
|
||||
@@ -538,15 +565,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ1_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
|
||||
let dl = d * (2 * f32((qh >> 12) & 7) + 1);
|
||||
let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
|
||||
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
|
||||
let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
|
||||
let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
|
||||
for (var l: u32 = 0; l < 4; l++) {
|
||||
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
|
||||
for (var j: u32 = 0; j < 8; j++) {
|
||||
@@ -610,13 +637,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
|
||||
#ifdef IQ4_NL
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
|
||||
let d = load_f16_as_f32_at(&src0, block_byte_base);
|
||||
var src1_i = src1_idx_base + offset * 32;
|
||||
var sum = 0.0;
|
||||
var qs: array<u32, 4>;
|
||||
for (var i: u32 = 0; i < 4; i++) {
|
||||
qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
|
||||
qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
|
||||
}
|
||||
for (var j: u32 = 0; j < 16; j++) {
|
||||
let qsb = get_byte(qs[j / 4], j % 4);
|
||||
@@ -631,8 +658,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
#ifdef IQ4_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
|
||||
let d = unpack2x16float(block.d_scales_h)[0];
|
||||
let scales_h = block.d_scales_h >> 16;
|
||||
var src1_i = src1_idx_base + offset * 256;
|
||||
var sum = 0.0;
|
||||
for (var ib: u32 = 0; ib < 8; ib++) {
|
||||
|
||||
@@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
@@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
@@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
@@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
@@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 80u);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 82u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 80u);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 82u);
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
@@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
|
||||
let is = k_in_block / 16u;
|
||||
|
||||
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
|
||||
let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
@@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 108u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 108u);
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
@@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
@@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
|
||||
hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
|
||||
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
@@ -499,13 +499,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
@@ -541,7 +541,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
@@ -575,13 +575,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = load_f16_at(&src0, block_byte_base);
|
||||
let dmin = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
@@ -621,11 +621,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
|
||||
@@ -673,17 +673,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
|
||||
let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
|
||||
let ql13_b = get_byte(ql13, 0u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
|
||||
let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
|
||||
let ql24_b = get_byte(ql24, 0u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
|
||||
let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
|
||||
let qh_b = get_byte(qh, 0u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
@@ -694,10 +694,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
|
||||
let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
|
||||
let sc_val = get_byte_i32(sc, 0u);
|
||||
|
||||
let d = load_src0_f16_at(block_byte_base + 208u);
|
||||
let d = load_f16_at(&src0, block_byte_base + 208u);
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
|
||||
@@ -65,10 +65,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
@@ -98,11 +98,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = f32(load_src0_f16_at(block_byte_base + 2u));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = f32(load_f16_at(&src0, block_byte_base + 2u));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
@@ -132,12 +132,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -176,13 +176,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
@@ -221,11 +221,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
@@ -254,12 +254,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let d = f32(load_f16_at(&src0, block_byte_base));
|
||||
let m = load_f16_at(&src0, block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
let q_packed = load_u32_at(&src0, q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + f32(m);
|
||||
@@ -309,13 +309,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
for (var i = ix; i < nb; i += 2u) {
|
||||
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = f32(load_src0_f16_at(bbase + 208u));
|
||||
let d = f32(load_f16_at(&src0, bbase + 208u));
|
||||
|
||||
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
|
||||
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
|
||||
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
|
||||
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
|
||||
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
|
||||
let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
|
||||
let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
|
||||
let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
|
||||
let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
|
||||
let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
|
||||
|
||||
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
|
||||
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
|
||||
|
||||
@@ -107,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
|
||||
#endif
|
||||
#ifdef EXP
|
||||
let res = exp(src[params.offset_src + src_idx]);
|
||||
let src_f32 = f32(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(exp(src_f32));
|
||||
#endif
|
||||
#ifdef LOG
|
||||
let res = TYPE(log(f32(src[params.offset_src + src_idx])));
|
||||
@@ -161,7 +162,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
|
||||
#endif
|
||||
#ifdef EXPM1
|
||||
let res = exp(src[params.offset_src + src_idx]) - 1.0;
|
||||
let src_f32 = f32(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(exp(src_f32) - 1.0);
|
||||
#endif
|
||||
#ifdef FLOOR
|
||||
let res = floor(src[params.offset_src + src_idx]);
|
||||
@@ -181,7 +183,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
|
||||
#endif
|
||||
#ifdef SQRT
|
||||
let res = sqrt(src[params.offset_src + src_idx]);
|
||||
let res = TYPE(sqrt(f32(src[params.offset_src + src_idx])));
|
||||
#endif
|
||||
#ifdef SIN
|
||||
let res_f32 = sin(f32(src[params.offset_src + src_idx]));
|
||||
|
||||
@@ -798,6 +798,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
A_ENC_INP_PROJ = auto() # gemma4
|
||||
A_ENC_CONV1D = auto()
|
||||
A_ENC_CONV1D_NORM = auto() # gemma3n
|
||||
A_ENC_CONV2D = auto()
|
||||
A_ENC_CONV_OUT = auto()
|
||||
A_PRE_NORM = auto()
|
||||
A_POST_NORM = auto()
|
||||
A_ENC_LAYER_PRE_NORM = auto() # gemma3n
|
||||
@@ -1280,6 +1282,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ: "a.input_projection",
|
||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||
MODEL_TENSOR.A_ENC_CONV2D: "a.conv2d.{bid}",
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT: "a.conv_out",
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
|
||||
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
|
||||
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
|
||||
@@ -1426,6 +1430,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
|
||||
MODEL_TENSOR.A_ENC_INP_PROJ,
|
||||
MODEL_TENSOR.A_ENC_CONV1D,
|
||||
MODEL_TENSOR.A_ENC_CONV2D,
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT,
|
||||
MODEL_TENSOR.A_ENC_CONV1D_NORM,
|
||||
MODEL_TENSOR.A_PRE_NORM,
|
||||
MODEL_TENSOR.A_POST_NORM,
|
||||
@@ -4112,9 +4118,11 @@ class VisionProjectorType:
|
||||
ULTRAVOX = "ultravox"
|
||||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
QWEN3A = "qwen3a" # audio
|
||||
GLMA = "glma" # audio
|
||||
QWEN25O = "qwen2.5o" # omni
|
||||
VOXTRAL = "voxtral"
|
||||
MERALION = "meralion" # audio: Whisper + gated MLP adaptor
|
||||
LFM2 = "lfm2"
|
||||
KIMIVL = "kimivl"
|
||||
PADDLEOCR = "paddleocr"
|
||||
|
||||
@@ -1892,6 +1892,14 @@ class TensorNameMap:
|
||||
"conformer.subsample_conv_projection.input_proj_linear", # gemma4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV2D: (
|
||||
"audio_tower.conv2d{bid}", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV_OUT: (
|
||||
"audio_tower.conv_out", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PRE_NORM: (),
|
||||
|
||||
MODEL_TENSOR.A_POST_NORM: (
|
||||
@@ -2041,8 +2049,9 @@ class TensorNameMap:
|
||||
# this prefix is added in the conversion code in modify_tensors()
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ: (
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox
|
||||
"audio_adapter.model.{bid}" # lfm2
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox, meralion
|
||||
"audio_adapter.model.{bid}", # lfm2
|
||||
"audio_tower.proj{bid}", # qwen3omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ_FC: (
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
{%- if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = false -%}
|
||||
{%- endif -%}
|
||||
{%- if not thinking is defined -%}
|
||||
{%- if enable_thinking is defined -%}
|
||||
{%- set thinking = enable_thinking -%}
|
||||
{%- else -%}
|
||||
{%- set thinking = false -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set dsml_token = '|DSML|' -%}
|
||||
{%- set thinking_start_token = '<think>' -%}
|
||||
{%- set thinking_end_token = '</think>' -%}
|
||||
{%- set tools_header = '## Tools\n\nYou have access to a set of tools you can use to answer the user\'s question.\nYou can invoke functions by writing a "<' + dsml_token + 'function_calls>" block like the following as part of your reply to the user:\n<' + dsml_token + 'function_calls>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$FUNCTION_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'function_calls>\n\nString and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).\n\nIf the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:\n\n<' + dsml_token + 'function_calls>\n...\n</' + dsml_token + 'function_calls>\n\n<function_results>\n...\n</function_results>\n\n' + thinking_start_token + '...thinking about results' + thinking_end_token + '\n\nHere are the functions available in JSONSchema format:\n<functions>\n' -%}
|
||||
{%- set tools_footer = '</functions>\n' -%}
|
||||
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- if ns.is_first_sp -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
|
||||
{%- set ns.is_first_sp = false -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if tools is defined and tools -%}
|
||||
{%- set ts = namespace(schemas='') -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool['type'] == 'function' -%}
|
||||
{%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
|
||||
{%- else -%}
|
||||
{%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{{- bos_token -}}
|
||||
{{- ns.system_prompt -}}
|
||||
{%- set last_user_idx = namespace(value=-1) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' or message['role'] == 'developer' -%}
|
||||
{%- set last_user_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set state = namespace(pending_asst_marker=false, pending_tool_marker=false) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '<|User|>' + (message['content'] or '') -}}
|
||||
{%- set state.pending_asst_marker = true -%}
|
||||
{%- set state.pending_tool_marker = false -%}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
|
||||
{%- if state.pending_asst_marker -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- if is_after_last_user and thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
|
||||
{{- message['reasoning_content'] -}}
|
||||
{%- endif -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- elif state.pending_tool_marker -%}
|
||||
{%- if is_after_last_user and thinking -%}
|
||||
{{- '\n\n' + thinking_start_token -}}
|
||||
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
|
||||
{{- message['reasoning_content'] -}}
|
||||
{%- endif -%}
|
||||
{{- thinking_end_token -}}
|
||||
{%- else -%}
|
||||
{{- '\n\n' + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- set state.pending_asst_marker = false -%}
|
||||
{%- set state.pending_tool_marker = false -%}
|
||||
{%- if message['content'] is defined and message['content'] -%}
|
||||
{{- message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- if message['tool_calls'] -%}
|
||||
{{- '\n\n<' + dsml_token + 'function_calls>\n' -}}
|
||||
{%- for tool in message['tool_calls'] -%}
|
||||
{%- set func = tool['function'] -%}
|
||||
{{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
|
||||
{%- set args = func['arguments'] -%}
|
||||
{%- if args is string -%}
|
||||
{%- set args = args | from_json -%}
|
||||
{%- endif -%}
|
||||
{%- for key, val in args.items() -%}
|
||||
{%- if val is string -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'invoke>\n' -}}
|
||||
{%- endfor -%}
|
||||
{{- '</' + dsml_token + 'function_calls>' -}}
|
||||
{%- endif -%}
|
||||
{{- '<|end▁of▁sentence|>' -}}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- set outer_index = loop.index0 -%}
|
||||
{%- set assistant_idx = namespace(value=-1) -%}
|
||||
{%- for prev_msg in messages -%}
|
||||
{%- if prev_msg['role'] == 'assistant' and prev_msg['tool_calls'] and loop.index0 < outer_index -%}
|
||||
{%- set assistant_idx.value = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set call_order = outer_index - assistant_idx.value -%}
|
||||
{%- set assistant_msg = messages[assistant_idx.value] -%}
|
||||
{%- set tool_call_count = assistant_msg['tool_calls'] | length -%}
|
||||
{%- if call_order == 1 -%}
|
||||
{{- '\n\n<function_results>' -}}
|
||||
{%- endif -%}
|
||||
{{- '\n<result>' + (message['content'] or '') + '</result>' -}}
|
||||
{%- if call_order == tool_call_count -%}
|
||||
{{- '\n</function_results>' -}}
|
||||
{%- set state.pending_asst_marker = false -%}
|
||||
{%- set state.pending_tool_marker = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if state.pending_asst_marker -%}
|
||||
{{- '<|Assistant|>' -}}
|
||||
{%- if thinking -%}
|
||||
{{- thinking_start_token -}}
|
||||
{%- else -%}
|
||||
{{- thinking_start_token + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- elif state.pending_tool_marker -%}
|
||||
{%- if thinking -%}
|
||||
{{- '\n\n' + thinking_start_token -}}
|
||||
{%- else -%}
|
||||
{{- '\n\n' + thinking_start_token + thinking_end_token -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -152,14 +152,14 @@
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None, last_user_message=-1) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{ bos_token }}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>' -}}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
@@ -255,13 +255,13 @@
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '\n\n<|image|>\n\n' -}}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '\n\n<|video|>\n\n' -}}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
@@ -11,34 +11,15 @@
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'OBJECT' -%}
|
||||
,properties:{
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
{%- elif value is mapping -%}
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- if value['required'] -%}
|
||||
,required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
,items:{
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
@@ -71,6 +52,32 @@
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'OBJECT' -%}
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
}
|
||||
{%- elif value is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- if value['required'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
@@ -150,16 +157,31 @@
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- macro format_tool_response_block(tool_name, response) -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if response is mapping -%}
|
||||
{{- 'response:' + tool_name + '{' -}}
|
||||
{%- for key, value in response | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{ bos_token }}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>' -}}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
@@ -180,11 +202,41 @@
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Pre-scan: find last user message index for reasoning guard -#}
|
||||
{%- set ns_turn = namespace(last_user_idx=-1) -%}
|
||||
{%- for i in range(loop_messages | length) -%}
|
||||
{%- if loop_messages[i]['role'] == 'user' -%}
|
||||
{%- set ns_turn.last_user_idx = i -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
|
||||
{%- set prev_nt = namespace(role=None, found=false) -%}
|
||||
{%- if loop.index0 > 0 -%}
|
||||
{%- for j in range(loop.index0 - 1, -1, -1) -%}
|
||||
{%- if not prev_nt.found -%}
|
||||
{%- if loop_messages[j]['role'] != 'tool' -%}
|
||||
{%- set prev_nt.role = loop_messages[j]['role'] -%}
|
||||
{%- set prev_nt.found = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
|
||||
{%- if not continue_same_model_turn -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render reasoning/reasoning_content as thinking channel -#}
|
||||
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
|
||||
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
|
||||
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
@@ -205,23 +257,49 @@
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_responses'] -%}
|
||||
{#- Tool Response handling -#}
|
||||
{%- set ns_tr_out = namespace(flag=false) -%}
|
||||
{%- if message.get('tool_responses') -%}
|
||||
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if tool_response['response'] is mapping -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
|
||||
{%- for key, value in tool_response['response'] | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endfor -%}
|
||||
{%- elif message.get('tool_calls') -%}
|
||||
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
|
||||
{%- set ns_tool_scan = namespace(stopped=false) -%}
|
||||
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
|
||||
{%- if ns_tool_scan.stopped -%}
|
||||
{%- elif loop_messages[k]['role'] != 'tool' -%}
|
||||
{%- set ns_tool_scan.stopped = true -%}
|
||||
{%- else -%}
|
||||
{%- set follow = loop_messages[k] -%}
|
||||
{#- Resolve tool_call_id to function name -#}
|
||||
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if tc.get('id') == follow.get('tool_call_id') -%}
|
||||
{%- set ns_tname.name = tc['function']['name'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Handle content as string or content-parts array -#}
|
||||
{%- set tool_body = follow.get('content') -%}
|
||||
{%- if tool_body is string -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- elif tool_body is sequence and tool_body is not string -%}
|
||||
{%- set ns_txt = namespace(s='') -%}
|
||||
{%- for part in tool_body -%}
|
||||
{%- if part.get('type') == 'text' -%}
|
||||
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
|
||||
{%- else -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- endif -%}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
@@ -239,28 +317,31 @@
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '\n\n<|image|>\n\n' -}}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '\n\n<|video|>\n\n' -}}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if not (message['tool_responses'] and not message['content']) -%}
|
||||
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- if not enable_thinking | default(false) -%}
|
||||
{{- '<|channel>thought\n<channel|>' -}}
|
||||
{%- if not enable_thinking | default(false) -%}
|
||||
{{- '<|channel>thought\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
|
||||
@@ -8,5 +8,5 @@ pandas~=2.2.3
|
||||
prometheus-client~=0.20.0
|
||||
requests~=2.32.3
|
||||
wget~=3.2
|
||||
typer~=0.15.1
|
||||
typer~=0.24.1
|
||||
seaborn~=0.13.2
|
||||
|
||||
@@ -22,9 +22,6 @@ device="HTP0"
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
profile=
|
||||
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -46,7 +43,7 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$ndev $nhvx $opmask $verbose $experimental $profile $hb ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||
$ndev $nhvx $opmask $verbose $profile $hb ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ubatch-size 256 -fa 1 -ngl 99 $cli_opts $@ \
|
||||
"
|
||||
|
||||
@@ -21,9 +21,6 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -48,13 +45,22 @@ ndev=
|
||||
hb=
|
||||
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||
|
||||
opbatch=
|
||||
[ "$OB" != "" ] && opbatch="GGML_HEXAGON_OPBATCH=$OB"
|
||||
|
||||
opqueue=
|
||||
[ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ"
|
||||
|
||||
opflt=
|
||||
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
|
||||
@@ -21,9 +21,6 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||
device="HTP0"
|
||||
[ "$D" != "" ] && device="$D"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -48,13 +45,22 @@ ndev=
|
||||
hb=
|
||||
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||
|
||||
opbatch=
|
||||
[ "$OB" != "" ] && opbatch="GGML_HEXAGON_OPBATCH=$OB"
|
||||
|
||||
opqueue=
|
||||
[ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ"
|
||||
|
||||
opflt=
|
||||
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
|
||||
@@ -21,9 +21,6 @@ device="HTP0"
|
||||
verbose=
|
||||
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V"
|
||||
|
||||
experimental=
|
||||
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||
|
||||
sched=
|
||||
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"
|
||||
|
||||
@@ -53,5 +50,5 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \
|
||||
"
|
||||
|
||||
@@ -20,10 +20,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:PROF) {
|
||||
$env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1
|
||||
}
|
||||
|
||||
@@ -20,10 +20,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
@@ -20,10 +20,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
@@ -29,12 +29,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
# Default experimental to 1
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=1
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
@@ -26,10 +26,6 @@ if ($null -ne $env:V) {
|
||||
$env:GGML_HEXAGON_VERBOSE=$env:V
|
||||
}
|
||||
|
||||
if ($null -ne $env:E) {
|
||||
$env:GGML_HEXAGON_EXPERIMENTAL=$env:E
|
||||
}
|
||||
|
||||
if ($null -ne $env:SCHED) {
|
||||
$env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v"
|
||||
}
|
||||
|
||||
+40
-23
@@ -202,24 +202,37 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
|
||||
const int64_t n_v_heads = hparams.ssm_dt_rank;
|
||||
const int64_t key_dim = head_k_dim * n_k_heads;
|
||||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t head_ratio = n_v_heads / n_k_heads;
|
||||
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
|
||||
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
|
||||
return std::vector<int64_t>(2 + head_ratio, key_dim);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
|
||||
return std::vector<int64_t>(head_ratio, key_dim);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
|
||||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
|
||||
return std::vector<int64_t>(head_ratio, n_k_heads);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_r_cache)) {
|
||||
return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_s_cache)) {
|
||||
return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
|
||||
|
||||
// both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different:
|
||||
// - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern)
|
||||
// - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern)
|
||||
if (ud->model->arch == LLM_ARCH_QWEN3NEXT) {
|
||||
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
|
||||
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
|
||||
return {key_dim, key_dim, value_dim};
|
||||
}
|
||||
} else {
|
||||
const int64_t head_ratio = n_v_heads / n_k_heads;
|
||||
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
|
||||
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
|
||||
return std::vector<int64_t>(2 + head_ratio, key_dim);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
|
||||
return std::vector<int64_t>(head_ratio, key_dim);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
|
||||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
|
||||
return std::vector<int64_t>(head_ratio, n_k_heads);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_r_cache)) {
|
||||
return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_s_cache)) {
|
||||
return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
|
||||
}
|
||||
}
|
||||
|
||||
// the FFN is the same for Qwen 3 Next and Qwen 3.5:
|
||||
if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) {
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp);
|
||||
@@ -249,13 +262,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
|
||||
const int64_t head_dim = hparams.ssm_d_state;
|
||||
const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
|
||||
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) ||
|
||||
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
|
||||
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
|
||||
return std::vector<int64_t>(segments.size(), granularity_qkv);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
|
||||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
|
||||
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
|
||||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
|
||||
return std::vector<int64_t>(segments.size(), granularity_qkv / head_dim);
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) {
|
||||
return std::vector<int64_t>(segments.size(), 2 * (granularity_qkv / head_dim));
|
||||
}
|
||||
if (std::regex_match(tensor_name, pattern_r_cache)) {
|
||||
return std::vector<int64_t>(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1));
|
||||
}
|
||||
@@ -300,7 +316,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
|
||||
|
||||
// FFN
|
||||
if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) ||
|
||||
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
|
||||
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
|
||||
GGML_ASSERT(segments.size() <= 2);
|
||||
return std::vector<int64_t>(segments.size(), blck_size);
|
||||
}
|
||||
@@ -4623,17 +4639,18 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(i);
|
||||
const int64_t n_embd_k = hparams.n_embd_k_gqa(i);
|
||||
const int64_t n_embd_v = hparams.n_embd_v_gqa(i);
|
||||
const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED;
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj)
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
@@ -354,7 +354,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
cb(last_conv_states, "last_conv_states", il);
|
||||
|
||||
ggml_tensor * state_update_target =
|
||||
ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
|
||||
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
|
||||
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
|
||||
cb(state_update_target, "state_update_target", il);
|
||||
|
||||
@@ -445,7 +445,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
// Update the recurrent states
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0, new_state,
|
||||
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
|
||||
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
|
||||
|
||||
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
|
||||
|
||||
@@ -258,6 +258,66 @@ void test_gbnf_generation(testing &t) {
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent parser emits nothing in gbnf", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.gbnf(p.literal("world"), "");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent choice inside sequence emits nothing", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.gbnf(p.literal("b") | p.literal("c"), "") + p.literal("d");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" "d"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("silent wrapped in tag emits nothing", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.tag("t", p.gbnf(p.literal("b"), ""));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("gbnf parser emits custom grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("a") + p.gbnf(p.literal("b"), "[a-z]+");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a" [a-z]+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("nested transparent wrappers get parenthesized", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("x") + p.tag("outer", p.atomic(p.literal("a") | p.literal("b")));
|
||||
|
||||
@@ -8397,6 +8397,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 512, 1, 1}, order)); // test CUDA dispatching to radix sort for nrows > = 1 in graph mode
|
||||
}
|
||||
|
||||
for (int n = 1; n < 5; ++n) {
|
||||
@@ -8579,7 +8580,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (int nb : { 1, 3, 32, 75, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL}) {
|
||||
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
|
||||
|
||||
@@ -2118,6 +2118,31 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.tools({ amount_tool })
|
||||
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
|
||||
.run();
|
||||
|
||||
// Edge cases
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel>thought\n<channel|>Hello, world!\nWhat's up?<|channel>thought\n<channel|><channel|>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<|channel><|channel>thought\n<channel|>Hello, world!\nWhat's up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -2576,6 +2601,215 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
expect(simple_assist_msg("CONTENT", "")).run();
|
||||
}
|
||||
|
||||
// DeepSeek V3.2 tests - format uses DSML markup:
|
||||
// <|DSML|function_calls>
|
||||
// <|DSML|invoke name="foo">
|
||||
// <|DSML|parameter name="bar" string="true|false">value</|DSML|parameter>
|
||||
// </|DSML|invoke>
|
||||
// </|DSML|function_calls>
|
||||
// Reasoning uses <think>...</think>. The generation prompt ends in <think> (thinking mode)
|
||||
// or <think></think> (non-thinking mode).
|
||||
{
|
||||
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.2.jinja", detailed_debug);
|
||||
|
||||
// Pure content (non-thinking mode)
|
||||
tst.test("Hello, world!\nWhat's up?")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
// Thinking + content
|
||||
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
|
||||
// Thinking + tool call (single, string param)
|
||||
tst.test(
|
||||
"Let me check the time</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Tokyo</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls_and_reasoning("get_time", R"({"city": "Tokyo"})", "Let me check the time"))
|
||||
.run();
|
||||
|
||||
// Tool call without reasoning (non-thinking mode), integer param (string="false")
|
||||
tst.test(
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"special_function\">\n"
|
||||
"<|DSML|parameter name=\"arg1\" string=\"false\">1</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Multiple parallel tool calls with reasoning
|
||||
tst.test(
|
||||
"Calling both</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Paris</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"<|DSML|invoke name=\"get_weather\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">Paris</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({ get_time_tool, get_weather_tool })
|
||||
.expect(message_with_reasoning_content_and_multiple_tool_calls(
|
||||
"Calling both", "",
|
||||
{ { "get_time", R"({"city": "Paris"})" }, { "get_weather", R"({"city": "Paris"})" } }))
|
||||
.run();
|
||||
|
||||
// Tool call with content before tool calls
|
||||
tst.test(
|
||||
"Thinking about it</think>"
|
||||
"Let me call the function.\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"special_function\">\n"
|
||||
"<|DSML|parameter name=\"arg1\" string=\"false\">1</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect_reasoning("Thinking about it")
|
||||
.expect_content("Let me call the function.")
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with negative number
|
||||
tst.test(
|
||||
"Test negative</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"magic_int\">\n"
|
||||
"<|DSML|parameter name=\"ref\" string=\"false\">-14</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ magic_int_tool })
|
||||
.expect_reasoning("Test negative")
|
||||
.expect_tool_calls({
|
||||
{ "magic_int", R"({"ref": -14})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with decimal number
|
||||
tst.test(
|
||||
"Test decimal</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"amount\">\n"
|
||||
"<|DSML|parameter name=\"orig\" string=\"false\">3.14</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ amount_tool })
|
||||
.expect_reasoning("Test decimal")
|
||||
.expect_tool_calls({
|
||||
{ "amount", R"({"orig": 3.14})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with boolean
|
||||
tst.test(
|
||||
"Test boolean</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"toggle\">\n"
|
||||
"<|DSML|parameter name=\"enabled\" string=\"false\">true</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ toggle_tool })
|
||||
.expect_reasoning("Test boolean")
|
||||
.expect_tool_calls({
|
||||
{ "toggle", R"({"enabled": true})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with array parameter (JSON-formatted)
|
||||
tst.test(
|
||||
"Test array</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"todo_list\">\n"
|
||||
"<|DSML|parameter name=\"todos\" string=\"false\">[\"buy milk\",\"walk dog\"]</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ todo_list })
|
||||
.expect_reasoning("Test array")
|
||||
.expect_tool_calls({
|
||||
{ "todo_list", R"({"todos": ["buy milk", "walk dog"]})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with object parameter (JSON-formatted)
|
||||
tst.test(
|
||||
"Test object</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"set_config\">\n"
|
||||
"<|DSML|parameter name=\"config\" string=\"false\">{\"theme\":\"dark\",\"level\":2}</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ config_tool })
|
||||
.expect_reasoning("Test object")
|
||||
.expect_tool_calls({
|
||||
{ "set_config", R"({"config": {"theme": "dark", "level": 2}})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Edge case: empty reasoning
|
||||
tst.test(
|
||||
"</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"get_time\">\n"
|
||||
"<|DSML|parameter name=\"city\" string=\"true\">XYZCITY</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", R"({"city": "XYZCITY"})"))
|
||||
.run();
|
||||
|
||||
// Edge case: tool call with multiple params (mixed types, string first)
|
||||
tst.test(
|
||||
"Multi-arg call</think>\n\n"
|
||||
"<|DSML|function_calls>\n"
|
||||
"<|DSML|invoke name=\"magic_int\">\n"
|
||||
"<|DSML|parameter name=\"ref\" string=\"false\">42</|DSML|parameter>\n"
|
||||
"<|DSML|parameter name=\"name\" string=\"true\">foo bar</|DSML|parameter>\n"
|
||||
"</|DSML|invoke>\n"
|
||||
"</|DSML|function_calls>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ magic_int_tool })
|
||||
.expect_reasoning("Multi-arg call")
|
||||
.expect_tool_calls({
|
||||
{ "magic_int", R"({"ref": 42, "name": "foo bar"})", {} },
|
||||
})
|
||||
.run();
|
||||
}
|
||||
|
||||
// GLM-4.6 tests - format: <tool_call>function_name\n<arg_key>...</arg_key>\n<arg_value>...</arg_value>\n</tool_call>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/GLM-4.6.jinja", detailed_debug);
|
||||
|
||||
@@ -88,6 +88,11 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
|
||||
uint32_t n_layer = 2;
|
||||
if (arch == LLM_ARCH_LLAMA4) {
|
||||
n_layer = 4; // hparams.n_no_rope_layer_step is hard-coded to 4
|
||||
} else if (arch == LLM_ARCH_GEMMA4) {
|
||||
n_embd = 128;
|
||||
n_head = 2;
|
||||
n_ff = 192;
|
||||
n_layer = 5; // need at least 5 for swa_pattern (every 5th is full_attention)
|
||||
} else if (arch == LLM_ARCH_GEMMA3N) {
|
||||
n_embd = 64;
|
||||
n_head = 1;
|
||||
@@ -169,7 +174,15 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
|
||||
ms.add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, uint32_t(8));
|
||||
ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, n_ctx/8);
|
||||
|
||||
if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
ms.add_kv(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, n_embd/2);
|
||||
ms.add_kv(LLM_KV_ATTENTION_SHARED_KV_LAYERS, uint32_t(0));
|
||||
ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, n_embd_head);
|
||||
ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, n_embd_head);
|
||||
ms.add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, 10000.0f);
|
||||
// SWA pattern: every 5th layer is full attention (matches E2B layer_types)
|
||||
ms.add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, uint32_t(5));
|
||||
} else if (arch == LLM_ARCH_MIMO2 || arch == LLM_ARCH_STEP35) {
|
||||
std::vector<uint32_t> pattern;
|
||||
pattern.reserve(n_layer);
|
||||
for (uint32_t il = 0; il < n_layer; il++) {
|
||||
@@ -429,6 +442,9 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
|
||||
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
|
||||
continue;
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
for (bool moe : {false, true}) {
|
||||
if (moe && !moe_implemented(arch)) {
|
||||
continue;
|
||||
@@ -510,6 +526,9 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
|
||||
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
|
||||
continue;
|
||||
}
|
||||
if (arch == LLM_ARCH_GEMMA4) {
|
||||
continue; // FIXME: ISWA KV cache initialization needs more fixture params
|
||||
}
|
||||
|
||||
const bool encode = arch == LLM_ARCH_T5 || arch == LLM_ARCH_DREAM || arch == LLM_ARCH_LLADA || arch == LLM_ARCH_LLADA_MOE || arch == LLM_ARCH_RND1;
|
||||
for (bool moe : {false, true}) {
|
||||
|
||||
@@ -1014,7 +1014,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
model.hf_file = params.hf_file[i];
|
||||
}
|
||||
|
||||
auto download_result = common_download_model(model, params.hf_token);
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = params.hf_token;
|
||||
auto download_result = common_download_model(model, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
fprintf(stderr, "error: failed to download model from HuggingFace\n");
|
||||
exit(1);
|
||||
|
||||
@@ -18,6 +18,7 @@ add_library(mtmd
|
||||
models/cogvlm.cpp
|
||||
models/conformer.cpp
|
||||
models/dotsocr.cpp
|
||||
models/gemma4a.cpp
|
||||
models/gemma4v.cpp
|
||||
models/glm4v.cpp
|
||||
models/hunyuanocr.cpp
|
||||
@@ -32,6 +33,7 @@ add_library(mtmd
|
||||
models/pixtral.cpp
|
||||
models/qwen2vl.cpp
|
||||
models/qwen3vl.cpp
|
||||
models/qwen3a.cpp
|
||||
models/step3vl.cpp
|
||||
models/siglip.cpp
|
||||
models/whisper-enc.cpp
|
||||
|
||||
@@ -135,6 +135,8 @@
|
||||
|
||||
// ultravox
|
||||
#define TN_CONV1D "a.conv1d.%d.%s"
|
||||
#define TN_CONV2D "a.conv2d.%d.%s"
|
||||
#define TN_CONV_OUT "a.conv_out.%s"
|
||||
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
|
||||
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
|
||||
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
|
||||
@@ -181,6 +183,21 @@
|
||||
#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s"
|
||||
#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s"
|
||||
|
||||
// gemma4 audio conformer
|
||||
#define TN_A_MM_INP_PROJ "mm.a.input_projection.%s"
|
||||
#define TN_A_MM_SOFT_EMB_N "mm.a.soft_emb_norm.%s"
|
||||
#define TN_A_INP_PROJ "a.input_projection.%s"
|
||||
#define TN_A_CONV1D "a.conv1d.%d.%s"
|
||||
#define TN_A_CONV1D_NORM "a.conv1d.%d.norm.%s"
|
||||
#define TN_A_OUT_PROJ "a.pre_encode.out.%s"
|
||||
#define TN_A_ATTN_PRE_NORM "%s.blk.%d.attn_pre_norm.%s"
|
||||
#define TN_A_ATTN_POST_NORM "%s.blk.%d.attn_post_norm.%s"
|
||||
#define TN_A_ATTN_K_REL "%s.blk.%d.attn_k_rel.%s"
|
||||
#define TN_A_PER_DIM_SCALE "%s.blk.%d.per_dim_scale.%s"
|
||||
#define TN_A_PER_DIM_K_SCALE "%s.blk.%d.per_dim_k_scale.%s"
|
||||
#define TN_A_FFN_POST_NORM "%s.blk.%d.ffn_post_norm.%s"
|
||||
#define TN_A_FFN_POST_NORM_1 "%s.blk.%d.ffn_post_norm_1.%s"
|
||||
|
||||
// mobilenetv5 (gemma3n) definitions
|
||||
#define TN_MNV5_STEM_CONV "v.conv_stem.conv.weight"
|
||||
#define TN_MNV5_STEM_BIAS "v.conv_stem.conv.bias"
|
||||
@@ -256,9 +273,11 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_INTERNVL,
|
||||
PROJECTOR_TYPE_LLAMA4,
|
||||
PROJECTOR_TYPE_QWEN2A,
|
||||
PROJECTOR_TYPE_QWEN3A,
|
||||
PROJECTOR_TYPE_GLMA,
|
||||
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
PROJECTOR_TYPE_MERALION,
|
||||
PROJECTOR_TYPE_MUSIC_FLAMINGO,
|
||||
PROJECTOR_TYPE_LFM2,
|
||||
PROJECTOR_TYPE_KIMIVL,
|
||||
@@ -299,9 +318,11 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
||||
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
|
||||
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
|
||||
{ PROJECTOR_TYPE_QWEN3A, "qwen3a"},
|
||||
{ PROJECTOR_TYPE_GLMA, "glma"},
|
||||
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
{ PROJECTOR_TYPE_MERALION, "meralion"},
|
||||
{ PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
|
||||
{ PROJECTOR_TYPE_LFM2, "lfm2"},
|
||||
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
|
||||
|
||||
+28
-1
@@ -217,6 +217,13 @@ struct clip_layer {
|
||||
ggml_tensor * conv_pw2_w = nullptr;
|
||||
ggml_tensor * conv_pw2_b = nullptr;
|
||||
|
||||
// gemma4 audio conformer per-layer
|
||||
ggml_tensor * attn_pre_norm_w = nullptr;
|
||||
ggml_tensor * attn_k_rel_w = nullptr;
|
||||
ggml_tensor * per_dim_scale_w = nullptr;
|
||||
ggml_tensor * per_dim_k_scale_w = nullptr;
|
||||
ggml_tensor * ff_post_norm_1_w = nullptr;
|
||||
|
||||
bool has_deepstack() const {
|
||||
return deepstack_fc1_w != nullptr;
|
||||
}
|
||||
@@ -406,10 +413,20 @@ struct clip_model {
|
||||
ggml_tensor * conv1d_1_b = nullptr;
|
||||
ggml_tensor * conv1d_2_w = nullptr;
|
||||
ggml_tensor * conv1d_2_b = nullptr;
|
||||
ggml_tensor * conv_out_w = nullptr;
|
||||
ggml_tensor * conv_out_b = nullptr;
|
||||
ggml_tensor * mm_norm_pre_w = nullptr;
|
||||
ggml_tensor * mm_norm_pre_b = nullptr;
|
||||
ggml_tensor * mm_norm_mid_w = nullptr;
|
||||
|
||||
// qwen3a
|
||||
ggml_tensor * conv2d_1_w = nullptr;
|
||||
ggml_tensor * conv2d_1_b = nullptr;
|
||||
ggml_tensor * conv2d_2_w = nullptr;
|
||||
ggml_tensor * conv2d_2_b = nullptr;
|
||||
ggml_tensor * conv2d_3_w = nullptr;
|
||||
ggml_tensor * conv2d_3_b = nullptr;
|
||||
|
||||
// cogvlm
|
||||
ggml_tensor * mm_post_fc_norm_w = nullptr;
|
||||
ggml_tensor * mm_post_fc_norm_b = nullptr;
|
||||
@@ -459,6 +476,15 @@ struct clip_model {
|
||||
};
|
||||
std::map<std::string, clamp_info> clamp_info_map;
|
||||
|
||||
// gemma4 audio conformer
|
||||
std::array<ggml_tensor *, 2> sscp_conv_w = {nullptr};
|
||||
std::array<ggml_tensor *, 2> sscp_conv_b = {nullptr};
|
||||
std::array<ggml_tensor *, 2> sscp_norm_w = {nullptr};
|
||||
ggml_tensor * sscp_inp_proj_w = nullptr;
|
||||
ggml_tensor * sscp_inp_proj_b = nullptr;
|
||||
ggml_tensor * audio_out_proj_w = nullptr;
|
||||
ggml_tensor * audio_out_proj_b = nullptr;
|
||||
|
||||
bool audio_has_avgpool() const {
|
||||
return proj_type == PROJECTOR_TYPE_QWEN2A
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL
|
||||
@@ -467,7 +493,8 @@ struct clip_model {
|
||||
|
||||
bool audio_has_stack_frames() const {
|
||||
return proj_type == PROJECTOR_TYPE_ULTRAVOX
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL
|
||||
|| proj_type == PROJECTOR_TYPE_MERALION;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
+222
-5
@@ -890,6 +890,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_whisper_enc>(ctx, img);
|
||||
@@ -930,10 +931,18 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
builder = std::make_unique<clip_graph_conformer>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_gemma4a>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_glm4v>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_qwen3a>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YOUTUVL:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_youtuvl>(ctx, img);
|
||||
@@ -1397,12 +1406,15 @@ struct clip_model_loader {
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
{
|
||||
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
|
||||
model.proj_type == PROJECTOR_TYPE_VOXTRAL ||
|
||||
model.proj_type == PROJECTOR_TYPE_MERALION ||
|
||||
model.proj_type == PROJECTOR_TYPE_GLMA;
|
||||
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
|
||||
hparams.ffn_op = FFN_GELU_ERF;
|
||||
@@ -1456,6 +1468,16 @@ struct clip_model_loader {
|
||||
hparams.audio_window_len = 400;
|
||||
hparams.audio_hop_len = 160;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
// Gemma4 feature_extraction_gemma4.py:
|
||||
// frame_length_ms=20 -> 320 samples, n_fft=512, hop=10ms -> 160
|
||||
hparams.audio_chunk_len = 0; // no fixed-length padding
|
||||
hparams.audio_sample_rate = 16000;
|
||||
hparams.audio_n_fft = 512;
|
||||
hparams.audio_window_len = 320; // 20ms frame (NOT 25ms/400)
|
||||
hparams.audio_hop_len = 160;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
{
|
||||
hparams.image_pad_color = {127, 127, 127};
|
||||
@@ -1558,16 +1580,21 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
// helper function
|
||||
std::unordered_set<std::string> loaded_tensor_names;
|
||||
auto get_tensor = [&](const std::string & name, bool required = true) {
|
||||
// Each tensor should only be loaded once; duplicates indicate a bug
|
||||
if (loaded_tensor_names.count(name)) {
|
||||
throw std::runtime_error(string_format("%s: tensor already loaded: %s\n", __func__, name.c_str()));
|
||||
}
|
||||
ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
|
||||
if (!cur && required) {
|
||||
throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
|
||||
}
|
||||
if (cur) {
|
||||
tensors_to_load.push_back(cur);
|
||||
// add tensors to context
|
||||
ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
|
||||
ggml_set_name(data_tensor, cur->name);
|
||||
loaded_tensor_names.insert(name);
|
||||
cur = data_tensor;
|
||||
}
|
||||
return cur;
|
||||
@@ -2017,6 +2044,30 @@ struct clip_model_loader {
|
||||
model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
|
||||
model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
{
|
||||
// Whisper encoder conv layers
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
|
||||
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
|
||||
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
|
||||
// MERaLiON adaptor: 4 linear layers + ln_pre
|
||||
// linear_0 = frame compression (19200->6400) + SiLU
|
||||
// linear_1 = gate_proj (6400->6400) for GLU
|
||||
// linear_2 = pool_proj (6400->6400) for GLU
|
||||
// linear_3 = out_proj (6400->3584)
|
||||
model.mm_0_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "weight"));
|
||||
model.mm_0_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "bias"));
|
||||
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
|
||||
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
|
||||
model.mm_3_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "weight"));
|
||||
model.mm_3_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "bias"));
|
||||
// ln_speech (LayerNorm before adaptor)
|
||||
model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
|
||||
model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
@@ -2026,6 +2077,20 @@ struct clip_model_loader {
|
||||
model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
|
||||
model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
model.conv2d_1_w = get_tensor(string_format(TN_CONV2D, 1, "weight"));
|
||||
model.conv2d_1_b = get_tensor(string_format(TN_CONV2D, 1, "bias"));
|
||||
model.conv2d_2_w = get_tensor(string_format(TN_CONV2D, 2, "weight"));
|
||||
model.conv2d_2_b = get_tensor(string_format(TN_CONV2D, 2, "bias"));
|
||||
model.conv2d_3_w = get_tensor(string_format(TN_CONV2D, 3, "weight"));
|
||||
model.conv2d_3_b = get_tensor(string_format(TN_CONV2D, 3, "bias"));
|
||||
model.conv_out_w = get_tensor(string_format(TN_CONV_OUT, "weight")); // no bias
|
||||
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
|
||||
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
|
||||
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
@@ -2159,6 +2224,76 @@ struct clip_model_loader {
|
||||
model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight"));
|
||||
model.mm_fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
for (int i = 0; i < 2; i++) {
|
||||
model.sscp_conv_w[i] = get_tensor(string_format(TN_A_CONV1D, i, "weight"));
|
||||
model.sscp_conv_b[i] = get_tensor(string_format(TN_A_CONV1D, i, "bias"), false);
|
||||
model.sscp_norm_w[i] = get_tensor(string_format(TN_A_CONV1D_NORM, i, "weight"), false);
|
||||
}
|
||||
model.sscp_inp_proj_w = get_tensor(string_format(TN_A_INP_PROJ, "weight"));
|
||||
model.sscp_inp_proj_b = get_tensor(string_format(TN_A_INP_PROJ, "bias"), false);
|
||||
model.audio_out_proj_w = get_tensor(string_format(TN_A_OUT_PROJ, "weight"), false);
|
||||
model.audio_out_proj_b = get_tensor(string_format(TN_A_OUT_PROJ, "bias"), false);
|
||||
// audio multimodal embedder (mm.a.* namespace, not mm.*)
|
||||
model.mm_soft_emb_norm_w = get_tensor(string_format(TN_A_MM_SOFT_EMB_N, "weight"), false);
|
||||
model.mm_input_proj_w = get_tensor(string_format(TN_A_MM_INP_PROJ, "weight"), false);
|
||||
|
||||
// Per-layer tensors NOT loaded by the generic loop above
|
||||
for (int il = 0; il < hparams.n_layer; ++il) {
|
||||
auto & layer = model.layers[il];
|
||||
|
||||
// Gemma4 audio conformer-specific tensors
|
||||
layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"));
|
||||
layer.attn_pre_norm_w = get_tensor(string_format(TN_A_ATTN_PRE_NORM, prefix, il, "weight"), false);
|
||||
layer.per_dim_scale_w = get_tensor(string_format(TN_A_PER_DIM_SCALE, prefix, il, "weight"), false);
|
||||
layer.per_dim_k_scale_w = get_tensor(string_format(TN_A_PER_DIM_K_SCALE, prefix, il, "weight"), false);
|
||||
layer.attn_k_rel_w = get_tensor(string_format(TN_A_ATTN_K_REL, prefix, il, "weight"), false);
|
||||
|
||||
// Convolution module
|
||||
// Note: conv_norm / norm_conv are swapped in GGUF due to
|
||||
// upstream tensor_mapping.py, so we load them in reverse order
|
||||
layer.norm_conv_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"), false);
|
||||
layer.norm_conv_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"), false);
|
||||
layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight"));
|
||||
layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias"), false);
|
||||
layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight"));
|
||||
layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias"), false);
|
||||
layer.conv_norm_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"), false);
|
||||
layer.conv_norm_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"), false);
|
||||
layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight"));
|
||||
layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias"), false);
|
||||
|
||||
// FFN2 (second half-step)
|
||||
layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
|
||||
layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight"));
|
||||
layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias"), false);
|
||||
layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
|
||||
layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"), false);
|
||||
layer.ff_post_norm_1_w = get_tensor(string_format(TN_A_FFN_POST_NORM_1, prefix, il, "weight"), false);
|
||||
}
|
||||
|
||||
// Load clamp info for ClippableLinear AFTER all tensors are loaded
|
||||
for (auto * tensor : tensors_to_load) {
|
||||
std::string name = tensor->name;
|
||||
if (string_ends_with(name, ".weight")) {
|
||||
std::string name_inp_max = name;
|
||||
std::string name_inp_min = name;
|
||||
std::string name_out_max = name;
|
||||
std::string name_out_min = name;
|
||||
string_replace_all(name_inp_max, ".weight", ".input_max");
|
||||
string_replace_all(name_inp_min, ".weight", ".input_min");
|
||||
string_replace_all(name_out_max, ".weight", ".output_max");
|
||||
string_replace_all(name_out_min, ".weight", ".output_min");
|
||||
model.clamp_info_map[name] = {
|
||||
get_scalar(name_inp_max, FLT_MAX),
|
||||
get_scalar(name_inp_min, -FLT_MAX),
|
||||
get_scalar(name_out_max, FLT_MAX),
|
||||
get_scalar(name_out_min, -FLT_MAX)
|
||||
};
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
{
|
||||
for (int i : {0, 2, 3, 5, 6}) {
|
||||
@@ -2219,7 +2354,10 @@ struct clip_model_loader {
|
||||
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
for (auto & t : tensors_to_load) {
|
||||
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
|
||||
const size_t offset = tensor_offset[t->name];
|
||||
GGML_ASSERT(cur && "tensor not found in ctx_data");
|
||||
auto it_off = tensor_offset.find(t->name);
|
||||
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
|
||||
const size_t offset = it_off->second;
|
||||
fin.seekg(offset, std::ios::beg);
|
||||
if (!fin) {
|
||||
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
|
||||
@@ -2239,6 +2377,7 @@ struct clip_model_loader {
|
||||
|
||||
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
struct support_info_op {
|
||||
@@ -2511,8 +2650,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
|
||||
|
||||
// TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
|
||||
// we can remove this check when we implement audio support for Gemma 3N
|
||||
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV
|
||||
|| ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA4V;
|
||||
skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
|
||||
}
|
||||
|
||||
if (loader.has_audio && !skip_audio) {
|
||||
@@ -2809,6 +2947,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
{
|
||||
n_patches = img->nx;
|
||||
@@ -2828,6 +2967,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
n_patches /= 2;
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
{
|
||||
// 3x stride-2 conv2d: each step is floor((n-1)/2)+1
|
||||
int n = img->nx;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n = (n - 1) / 2 + 1;
|
||||
n_patches = n;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
{
|
||||
n_patches = img->nx;
|
||||
@@ -2865,6 +3013,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
{
|
||||
n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
// Two Conv2D stride-2: O = floor((I + 2p - k) / s) + 1, p=1, k=3, s=2
|
||||
// O = floor((I - 1) / 2) + 1
|
||||
int n = img->nx;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
n = (n - 1) / 2 + 1;
|
||||
}
|
||||
n_patches = n;
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported projector type");
|
||||
}
|
||||
@@ -3294,10 +3452,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
case PROJECTOR_TYPE_PHI4:
|
||||
@@ -3323,6 +3483,56 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
set_input_i32("pos_w", pos_data);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
{
|
||||
GGML_ASSERT(imgs.entries.size() == 1);
|
||||
const auto & img0 = imgs.entries.front();
|
||||
// Compute n_pos matching SSCP output: two stride-2 convs
|
||||
int n_pos = img0->nx;
|
||||
for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }
|
||||
|
||||
// Chunked local attention: blocked causal mask and RPE
|
||||
const int chunk_size = 12;
|
||||
const int max_past = 12;
|
||||
const int context_size = chunk_size + max_past;
|
||||
const int num_blocks = (n_pos + chunk_size - 1) / chunk_size;
|
||||
|
||||
// Blocked causal attention mask: [context_size, chunk_size, num_blocks]
|
||||
{
|
||||
std::vector<float> mask(context_size * chunk_size * num_blocks, -1e9f);
|
||||
for (int b = 0; b < num_blocks; b++) {
|
||||
for (int q = 0; q < chunk_size; q++) {
|
||||
int gq = b * chunk_size + q;
|
||||
for (int k = 0; k < context_size; k++) {
|
||||
int gk = b * chunk_size - max_past + k;
|
||||
if (gq < n_pos && gk >= 0 && gk < n_pos && gk <= gq && (gq - gk) < max_past) {
|
||||
mask[k + q * context_size + b * context_size * chunk_size] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
set_input_f32("kq_mask", mask);
|
||||
}
|
||||
|
||||
// Sinusoidal RPE: 13 positions [12, 11, ..., 0]
|
||||
{
|
||||
const int n_embd = ctx->model.hparams.n_embd;
|
||||
const int num_timescales = n_embd / 2;
|
||||
const float log_timescale_increment = logf(10000.0f) / std::max(num_timescales - 1, 1);
|
||||
const int rpe_len = max_past + 1;
|
||||
std::vector<float> pos_emb(n_embd * rpe_len, 0.0f);
|
||||
for (int p = 0; p < rpe_len; p++) {
|
||||
float position = (float)(max_past - p);
|
||||
for (int i = 0; i < num_timescales; i++) {
|
||||
float inv_ts = expf(-(float)i * log_timescale_increment);
|
||||
float scaled = position * inv_ts;
|
||||
pos_emb[p * n_embd + i] = sinf(scaled);
|
||||
pos_emb[p * n_embd + i + num_timescales] = cosf(scaled);
|
||||
}
|
||||
}
|
||||
set_input_f32("pos_emb", pos_emb);
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
{
|
||||
GGML_ASSERT(imgs.entries.size() == 1);
|
||||
@@ -3463,6 +3673,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
return ctx->model.mm_3_w->ne[1]; // out_proj output dim
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
return ctx->model.mm_3_w->ne[1];
|
||||
@@ -3470,8 +3682,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_model_proj->ne[1];
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
return ctx->model.mm_fc_w->ne[1];
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_PADDLEOCR:
|
||||
@@ -3485,6 +3698,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_fc_w->ne[1];
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
return ctx->model.position_embeddings->ne[0];
|
||||
case PROJECTOR_TYPE_GEMMA4A:
|
||||
return ctx->model.hparams.projection_dim;
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
return ctx->model.mm_ffn_down_w->ne[1];
|
||||
default:
|
||||
@@ -3521,8 +3736,10 @@ bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
||||
switch (ctx->proj_type()) {
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN3A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_MERALION:
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
return true;
|
||||
default:
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
/**
|
||||
* Gemma 4 Audio Conformer Encoder (clip_graph_gemma4a)
|
||||
*
|
||||
* Architecture: Conformer with dual half-step FFN, full self-attention
|
||||
* with sinusoidal RPE, depthwise light conv, and output projection.
|
||||
*/
|
||||
|
||||
#include "models.h"
|
||||
#include <cmath>
|
||||
|
||||
ggml_cgraph * clip_graph_gemma4a::build() {
|
||||
const float res_weight = 0.5f;
|
||||
const float norm_eps = 1e-6f;
|
||||
|
||||
// 1. Input
|
||||
ggml_tensor * inp = build_inp_raw(1);
|
||||
auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
||||
|
||||
// 2. Subsampling Conv2D (symmetric padding=1, matching PyTorch)
|
||||
{
|
||||
for (int i = 0; i < 2; i++) {
|
||||
cur = ggml_conv_2d(ctx0, model.sscp_conv_w[i], cur, 2, 2, 1, 1, 1, 1);
|
||||
if (model.sscp_conv_b[i]) {
|
||||
cur = ggml_add(ctx0, cur, model.sscp_conv_b[i]);
|
||||
}
|
||||
// nn.LayerNorm(channels): permute ch to ne[0], normalize, permute back
|
||||
if (model.sscp_norm_w[i]) {
|
||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
||||
cur = ggml_norm(ctx0, cur, norm_eps);
|
||||
cur = ggml_mul(ctx0, cur, model.sscp_norm_w[i]);
|
||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
||||
}
|
||||
cur = ggml_relu(ctx0, cur);
|
||||
}
|
||||
// Flatten [freq, time, ch, 1] -> [ch*freq, time]
|
||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
|
||||
if (model.sscp_inp_proj_w) {
|
||||
cur = build_mm(model.sscp_inp_proj_w, cur);
|
||||
if (model.sscp_inp_proj_b) {
|
||||
cur = ggml_add(ctx0, cur, model.sscp_inp_proj_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t n_pos = cur->ne[1];
|
||||
|
||||
// Chunked local attention parameters
|
||||
const int64_t C = 12; // chunk_size
|
||||
const int64_t P = 12; // max_past_horizon (context_left - 1)
|
||||
const int64_t S = C + P; // context_size = 24
|
||||
const int64_t R = P + 1; // RPE positions = 13
|
||||
const int64_t B = (n_pos + C - 1) / C; // num_blocks
|
||||
const int64_t Np = B * C; // padded sequence length
|
||||
const int64_t pad_seq = Np - n_pos;
|
||||
|
||||
// Input tensors: blocked RPE and blocked attention mask
|
||||
ggml_tensor * pos_emb = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_head * d_head, R);
|
||||
ggml_set_name(pos_emb, "pos_emb");
|
||||
ggml_set_input(pos_emb);
|
||||
|
||||
ggml_tensor * kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, S, C, B);
|
||||
ggml_set_name(kq_mask, "kq_mask");
|
||||
ggml_set_input(kq_mask);
|
||||
|
||||
// 3. Conformer Blocks
|
||||
for (int il = 0; il < hparams.n_layer; il++) {
|
||||
const auto & layer = model.layers[il];
|
||||
auto * residual = cur;
|
||||
|
||||
// FFN 1 (half-step)
|
||||
if (layer.ff_norm_w && layer.ff_up_w && layer.ff_down_w) {
|
||||
cur = build_norm(cur, layer.ff_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
cur = build_ffn(cur,
|
||||
layer.ff_up_w, nullptr, nullptr, nullptr,
|
||||
layer.ff_down_w, nullptr, FFN_SILU, il);
|
||||
if (layer.ff_post_norm_w) {
|
||||
cur = build_norm(cur, layer.ff_post_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
}
|
||||
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, res_weight));
|
||||
}
|
||||
|
||||
// Chunked local self-attention with RPE
|
||||
if (layer.q_w && layer.k_w && layer.v_w && layer.o_w) {
|
||||
const float q_scale = (1.0f / sqrtf((float)d_head)) / logf(2.0f);
|
||||
const float k_scale = logf(1.0f + expf(1.0f)) / logf(2.0f);
|
||||
const float softcap = 50.0f;
|
||||
|
||||
ggml_tensor * attn_norm_w = layer.attn_pre_norm_w ? layer.attn_pre_norm_w : layer.ln_1_w;
|
||||
cur = attn_norm_w
|
||||
? build_norm(residual, attn_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il)
|
||||
: residual;
|
||||
|
||||
ggml_tensor * Qcur = build_mm(layer.q_w, cur);
|
||||
ggml_tensor * Kcur = build_mm(layer.k_w, cur);
|
||||
ggml_tensor * Vcur = build_mm(layer.v_w, cur);
|
||||
|
||||
// [n_embd, n_pos] -> [D, H, N]
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
|
||||
|
||||
// Q/K scaling
|
||||
Qcur = ggml_scale(ctx0, Qcur, q_scale);
|
||||
if (layer.per_dim_scale_w) {
|
||||
Qcur = ggml_mul(ctx0, Qcur, ggml_reshape_3d(ctx0, layer.per_dim_scale_w, d_head, 1, 1));
|
||||
}
|
||||
Kcur = ggml_scale(ctx0, Kcur, k_scale);
|
||||
if (layer.per_dim_k_scale_w) {
|
||||
Kcur = ggml_mul(ctx0, Kcur, ggml_reshape_3d(ctx0, layer.per_dim_k_scale_w, d_head, 1, 1));
|
||||
}
|
||||
|
||||
// Q blocking: [D, H, N] -> pad to Np -> reshape [D, H, C, B]
|
||||
// ggml permute: ne[ax_i] = src->ne[i], so (0,3,1,2) sends H->3, C->1, B->2
|
||||
Qcur = ggml_pad(ctx0, Qcur, 0, 0, pad_seq, 0); // [D, H, Np]
|
||||
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, C, B); // [D, H, C, B]
|
||||
Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 3, 1, 2)); // [D, C, B, H]
|
||||
|
||||
// K/V block context extraction via overlapping view:
|
||||
// Pad to S*B elements, roll right by P to create left-padding,
|
||||
// then view with stride C in the block dimension (overlapping windows).
|
||||
auto extract_blocks = [&](ggml_tensor * t) -> ggml_tensor * {
|
||||
// [D, H, N] -> pad to S*B -> roll right by P -> cont (materialize)
|
||||
const int64_t pad_kv = S * B - n_pos;
|
||||
t = ggml_pad(ctx0, t, 0, 0, pad_kv, 0); // [D, H, S*B]
|
||||
t = ggml_roll(ctx0, t, 0, 0, P, 0); // left-pad by P
|
||||
t = ggml_cont(ctx0, t); // materialize roll (removes view offset)
|
||||
// Overlapping view: stride for B dim is C positions, not S
|
||||
// ne = [D, H, S, B], data_size = D*H*S*B*sizeof = source_nbytes (exact fit)
|
||||
// nb1=D*sizeof, nb2=D*H*sizeof, nb3=C*D*H*sizeof (overlap: C < S)
|
||||
t = ggml_view_4d(ctx0, t, d_head, n_head, S, B,
|
||||
t->nb[1], t->nb[2], C * t->nb[2], 0);
|
||||
t = ggml_cont(ctx0, t); // materialize overlapping windows
|
||||
return t;
|
||||
};
|
||||
|
||||
ggml_tensor * Kblk = extract_blocks(Kcur);
|
||||
// [D, H, S, B] -> [D, S, B, H] via permute(0,3,1,2)
|
||||
Kblk = ggml_cont(ctx0, ggml_permute(ctx0, Kblk, 0, 3, 1, 2));
|
||||
|
||||
ggml_tensor * Vblk = extract_blocks(Vcur);
|
||||
// [D, H, S, B] -> [S, D, B, H] via permute(1,3,0,2)
|
||||
Vblk = ggml_cont(ctx0, ggml_permute(ctx0, Vblk, 1, 3, 0, 2));
|
||||
|
||||
// Content attention: Q @ K^T
|
||||
// Kblk=[D,S,B,H], Qcur=[D,C,B,H] -> mul_mat contracts on D -> [S,C,B,H]
|
||||
ggml_tensor * matrix_ac = ggml_mul_mat(ctx0, Kblk, Qcur);
|
||||
|
||||
// Relative position attention
|
||||
if (layer.attn_k_rel_w) {
|
||||
// RPE: [n_embd, R] -> project -> [D, H, R] -> [D, R, H]
|
||||
auto * p = ggml_mul_mat(ctx0, layer.attn_k_rel_w, pos_emb);
|
||||
p = ggml_reshape_3d(ctx0, p, d_head, n_head, R);
|
||||
p = ggml_cont(ctx0, ggml_permute(ctx0, p, 0, 2, 1, 3)); // [D, R, H]
|
||||
|
||||
// Q_flat @ RPE^T: [D, C*B, H] @ [D, R, H] -> [R, C*B, H]
|
||||
auto * Q_flat = ggml_reshape_3d(ctx0, Qcur, d_head, C * B, n_head);
|
||||
auto * matrix_bd = ggml_mul_mat(ctx0, p, Q_flat); // [R, C*B, H]
|
||||
matrix_bd = ggml_reshape_4d(ctx0, matrix_bd, R, C, B, n_head); // [R, C, B, H]
|
||||
|
||||
// Blocked relative shift (appendix B of Transformer-XL)
|
||||
{
|
||||
matrix_bd = ggml_pad(ctx0, matrix_bd, S + 1 - R, 0, 0, 0); // [S+1, C, B, H]
|
||||
matrix_bd = ggml_reshape_3d(ctx0, matrix_bd, (S + 1) * C, B, n_head);
|
||||
matrix_bd = ggml_view_3d(ctx0, matrix_bd,
|
||||
C * S, B, n_head,
|
||||
matrix_bd->nb[1], matrix_bd->nb[2], 0);
|
||||
matrix_bd = ggml_cont(ctx0, matrix_bd); // [C*S, B, H]
|
||||
matrix_bd = ggml_reshape_4d(ctx0, matrix_bd, S, C, B, n_head); // [S, C, B, H]
|
||||
}
|
||||
|
||||
matrix_ac = ggml_add(ctx0, matrix_ac, matrix_bd);
|
||||
}
|
||||
|
||||
auto * scores = matrix_ac; // [S, C, B, H]
|
||||
|
||||
// Softcap
|
||||
scores = ggml_scale(ctx0, scores, 1.0f / softcap);
|
||||
scores = ggml_tanh(ctx0, scores);
|
||||
scores = ggml_scale(ctx0, scores, softcap);
|
||||
|
||||
// Blocked attention mask: [S, C, B] broadcasts over H
|
||||
scores = ggml_add(ctx0, scores, kq_mask);
|
||||
|
||||
ggml_tensor * attn = ggml_soft_max(ctx0, scores);
|
||||
|
||||
// attn @ V: [S,C,B,H] @ [S,D,B,H] -> [D,C,B,H]
|
||||
ggml_tensor * x = ggml_mul_mat(ctx0, Vblk, attn);
|
||||
|
||||
// [D,C,B,H] -> [D,H,C,B] via permute(0,2,3,1) -> flatten -> trim
|
||||
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 3, 1));
|
||||
x = ggml_cont_2d(ctx0, x, d_head * n_head, C * B);
|
||||
if (pad_seq > 0) {
|
||||
x = ggml_view_2d(ctx0, x, d_head * n_head, n_pos, x->nb[1], 0);
|
||||
x = ggml_cont(ctx0, x);
|
||||
}
|
||||
|
||||
x = build_mm(layer.o_w, x);
|
||||
if (layer.o_b) { x = ggml_add(ctx0, x, layer.o_b); }
|
||||
|
||||
if (layer.attn_post_norm_w) {
|
||||
x = build_norm(x, layer.attn_post_norm_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
}
|
||||
residual = ggml_add(ctx0, residual, x);
|
||||
}
|
||||
|
||||
// Convolution Module
|
||||
if (layer.norm_conv_w && layer.conv_pw1_w && layer.conv_dw_w && layer.conv_pw2_w) {
|
||||
cur = build_norm(residual, layer.norm_conv_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
auto * x = build_mm(layer.conv_pw1_w, cur);
|
||||
|
||||
// GLU
|
||||
{
|
||||
int64_t d = x->ne[0] / 2;
|
||||
ggml_tensor * gate = ggml_sigmoid(ctx0,
|
||||
ggml_cont(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], d * x->nb[0])));
|
||||
x = ggml_mul(ctx0,
|
||||
ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], 0), gate);
|
||||
x = ggml_cont(ctx0, ggml_transpose(ctx0, x));
|
||||
}
|
||||
|
||||
// Causal depthwise Conv1D via ggml_ssm_conv (pad+roll for left-only padding).
|
||||
x = ggml_pad(ctx0, x, 4, 0, 0, 0);
|
||||
x = ggml_roll(ctx0, x, 4, 0, 0, 0);
|
||||
x = ggml_ssm_conv(ctx0, x, layer.conv_dw_w);
|
||||
if (layer.conv_dw_b) {
|
||||
x = ggml_add(ctx0, x, layer.conv_dw_b);
|
||||
}
|
||||
|
||||
if (layer.conv_norm_w) {
|
||||
x = ggml_rms_norm(ctx0, x, norm_eps);
|
||||
x = ggml_mul(ctx0, x, layer.conv_norm_w);
|
||||
}
|
||||
x = ggml_silu(ctx0, x);
|
||||
x = build_mm(layer.conv_pw2_w, x);
|
||||
residual = ggml_add(ctx0, residual, x);
|
||||
}
|
||||
|
||||
// FFN 2 (half-step)
|
||||
if (layer.ff_norm_1_w && layer.ff_up_1_w && layer.ff_down_1_w) {
|
||||
cur = build_norm(residual, layer.ff_norm_1_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
cur = build_ffn(cur,
|
||||
layer.ff_up_1_w, nullptr, nullptr, nullptr,
|
||||
layer.ff_down_1_w, nullptr, FFN_SILU, il);
|
||||
if (layer.ff_post_norm_1_w) {
|
||||
cur = build_norm(cur, layer.ff_post_norm_1_w, nullptr, NORM_TYPE_RMS, norm_eps, il);
|
||||
}
|
||||
residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, res_weight));
|
||||
}
|
||||
|
||||
// Layer output norm
|
||||
cur = layer.ln_2_w
|
||||
? build_norm(residual, layer.ln_2_w, nullptr, NORM_TYPE_RMS, norm_eps, il)
|
||||
: residual;
|
||||
|
||||
}
|
||||
|
||||
// 4. Output Projection
|
||||
if (model.audio_out_proj_w) {
|
||||
cur = build_mm(model.audio_out_proj_w, cur);
|
||||
if (model.audio_out_proj_b) {
|
||||
cur = ggml_add(ctx0, cur, model.audio_out_proj_b);
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Audio Multimodal Embedder
|
||||
cur = ggml_rms_norm(ctx0, cur, norm_eps);
|
||||
if (model.mm_soft_emb_norm_w) {
|
||||
cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
|
||||
}
|
||||
if (model.mm_input_proj_w) {
|
||||
cur = build_mm(model.mm_input_proj_w, cur);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_tensor * clip_graph_gemma4a::build_mm(ggml_tensor * w, ggml_tensor * x) const {
|
||||
auto it = model.clamp_info_map.find(w->name);
|
||||
if (it == model.clamp_info_map.end()) {
|
||||
return ggml_mul_mat(ctx0, w, x);
|
||||
}
|
||||
const auto & ci = it->second;
|
||||
ggml_tensor * clamped = ggml_clamp(ctx0, x, ci.inp_min, ci.inp_max);
|
||||
ggml_tensor * out = ggml_mul_mat(ctx0, w, clamped);
|
||||
return ggml_clamp(ctx0, out, ci.out_min, ci.out_max);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user