Compare commits

..

1 Commits

Author SHA1 Message Date
YiChen Lv bec3083830 metal : per-op source split + parallel compile (#24021)
* preliminary extract common header

* op source split

* split metallib into 8 libs && load in parallel

* derive kernel->library routing from functionNames

* x-macro lib list + underscore filenames, dedup QK_NL, MRC fixes

* op source split 8 to 20

* improve robustness of source fallback

* clean up

* change bool -> atomic_bool

* only prepend headers that source actually includes

* no semaphore, use GCD global queue

* dedup library compile path, fix NSError lifetime, rename gla

* relocate upstream concat/rope_back/repeat kernel changes into split files

* move ggml-common.h from common.h into dequantize.h to shrink binary size

---------

Co-authored-by: lvyichen <lvyichen@stepfun.com>
2026-06-22 14:15:48 +03:00
232 changed files with 23373 additions and 24902 deletions
+1 -1
View File
@@ -10,7 +10,7 @@
# ggml-org/ggml-rpc : rgerganov
# ggml-org/ggml-sycl : arthw
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
# ggml-org/ggml-webgpu : reeselevine, yomaytk
# ggml-org/ggml-webgpu : reeselevine
# ggml-org/ggml-zdnn : taronaeo
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
# ggml-org/llama-mtmd : ngxson
+1 -3
View File
@@ -142,9 +142,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
+2
View File
@@ -80,6 +80,8 @@ add_library(${TARGET}
http.h
imatrix-loader.cpp
imatrix-loader.h
json-partial.cpp
json-partial.h
json-schema-to-grammar.cpp
llguidance.cpp
log.cpp
+5 -14
View File
@@ -301,8 +301,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
const common_download_opts & opts) {
handle_model_result result;
// TODO @ngxson : refactor this into a new common_model_download_context
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
} else if (!model.hf_repo.empty()) {
@@ -398,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
// CLI argument parsing functions
//
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
@@ -409,11 +407,6 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
opts.preset_only = handle_params.preset_only;
if (handle_params.callback) {
opts.callback = handle_params.callback;
}
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
// so we should not auto-discover mtp/mmproj siblings for them
@@ -591,19 +584,17 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
const bool skip_model_download =
// server will call common_params_handle_models() later, so we skip it here
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
// export_graph_ops loads only metadata
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex, {});
common_params_handle_models(params, ctx_arg.ex);
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
&& !params.usage
&& !params.completion) {
throw std::invalid_argument("error: --model is required\n");
+1 -10
View File
@@ -1,7 +1,6 @@
#pragma once
#include "common.h"
#include "download.h"
#include <set>
#include <map>
@@ -130,19 +129,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_params_handle_models_params {
common_download_callback * callback = nullptr;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
};
// populate model paths (main model, mmproj, etc) from -hf if necessary
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
bool common_params_handle_models(
common_params & params,
llama_example curr_ex,
const common_params_handle_models_params & handle_params);
bool common_params_handle_models(common_params & params, llama_example curr_ex);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+53 -103
View File
@@ -90,93 +90,41 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
return text;
}
common_chat_role common_chat_role_from_string(const std::string & role) {
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
return COMMON_CHAT_ROLE_UNKNOWN;
}
const char * common_chat_role_to_string(common_chat_role role) {
switch (role) {
case COMMON_CHAT_ROLE_SYSTEM: return "system";
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
case COMMON_CHAT_ROLE_USER: return "user";
case COMMON_CHAT_ROLE_TOOL: return "tool";
case COMMON_CHAT_ROLE_UNKNOWN: return "";
}
return "";
}
json common_chat_msg_delimiters::to_json() const {
json result = json::array();
for (const auto & d : delimiters) {
result.push_back({
{ "role", common_chat_role_to_string(d.role) },
{ "delimiter", d.delimiter },
});
}
return result;
}
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
common_chat_msg_delimiters result;
if (!delimiters.is_array()) {
return result;
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
}
result.delimiters.reserve(delimiters.size());
for (const auto & d : delimiters) {
if (!d.is_object()) {
continue;
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;
all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
}
result.delimiters.push_back({
common_chat_role_from_string(d.value("role", std::string())),
d.value("delimiter", std::string()),
});
}
return result;
}
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
for (auto & d : delimiters) {
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
}
}
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
std::vector<std::pair<common_chat_role, size_t>> matches;
auto skip = skips.begin();
for (size_t i = 0; i < tokens.size();) {
if (skip != skips.end() && i == skip->first) {
i += skip->second;
++skip;
continue;
auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}
for (const auto & d : delimiters) {
if (i + d.tokens.size() > tokens.size()) {
continue;
}
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
matches.emplace_back(d.role, i);
break;
}
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});
common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
}
std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
}
i++;
}
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
common_chat_msg_spans spans;
for (size_t i = 0; i + 1 < matches.size(); i++) {
const auto & curr = matches[i];
const auto & next = matches[i + 1];
spans.add(curr.first, curr.second, next.second - curr.second);
}
});
return spans;
}
@@ -1133,13 +1081,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
};
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
@@ -1280,10 +1228,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt += data.generation_prompt;
}
data.message_delimiters = {
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
};
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "user", "<|turn>user\n" },
{ "assistant", "<|turn>model\n" },
});
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
@@ -2082,15 +2030,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
RESULT_START, RESULT_END,
};
// Declare per-role message delimiters. Tool results are rendered with the
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
};
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "assistant", GEN_PREFIX },
{ "user", TURN_START + USER },
{ "tool", TURN_START + SYSTEM + RESULT_START },
{ "system", TURN_START + SYSTEM },
});
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
@@ -2578,15 +2526,17 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
common_chat_msg_delimiters delimiters;
std::vector<common_chat_msg_delimiter> delimiters;
if (!autoparser.assistant_start.empty()) {
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
delimiters.push_back({ "assistant", autoparser.assistant_start });
}
if (!autoparser.user_start.empty()) {
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
delimiters.push_back({ "user", autoparser.user_start });
}
auto_params.message_delimiters = std::move(delimiters);
if (!delimiters.empty()) {
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
}
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
+6 -65
View File
@@ -143,75 +143,15 @@ struct common_chat_msg_diff {
}
};
enum common_chat_role {
COMMON_CHAT_ROLE_UNKNOWN,
COMMON_CHAT_ROLE_SYSTEM,
COMMON_CHAT_ROLE_ASSISTANT,
COMMON_CHAT_ROLE_USER,
COMMON_CHAT_ROLE_TOOL
};
common_chat_role common_chat_role_from_string(const std::string & role);
const char * common_chat_role_to_string(common_chat_role role);
struct common_chat_msg_span {
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string role;
std::size_t pos = 0;
std::size_t len = 0;
bool valid() const {
return role != COMMON_CHAT_ROLE_UNKNOWN;
}
};
struct common_chat_msg_spans {
std::vector<common_chat_msg_span> spans;
void add(common_chat_role role, size_t pos, size_t len) {
spans.push_back({ role, pos, len });
}
bool is_user_start(int32_t pos) const {
for (auto it = spans.begin(); it != spans.end(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
return true;
}
}
return false;
}
int32_t last_user_message_pos() const {
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER) {
return (int32_t) it->pos;
}
}
return -1;
}
};
struct common_chat_msg_delimiter {
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string delimiter;
llama_tokens tokens = {};
};
struct common_chat_msg_delimiters {
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters() = default;
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
void add(common_chat_role role, const std::string & delimiter) {
delimiters.push_back({ role, delimiter });
}
void tokenize(const llama_vocab * vocab);
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
nlohmann::ordered_json to_json() const;
std::string role;
std::string delimiter;
};
struct common_chat_tool {
@@ -279,7 +219,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
common_chat_msg_delimiters message_delimiters;
std::vector<common_chat_msg_span> message_spans;
};
// per-message parsing syntax
@@ -385,4 +325,5 @@ struct common_chat_prompt_preset {
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
+1 -1
View File
@@ -609,7 +609,7 @@ struct common_params {
bool cache_prompt = true; // whether to enable prompt caching
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
+1 -3
View File
@@ -799,7 +799,6 @@ common_download_model_result common_download_model(const common_params_model &
bool download_mmproj = opts.download_mmproj;
bool download_mtp = opts.download_mtp;
bool preset_only = opts.preset_only;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
@@ -807,8 +806,7 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else if (!preset_only) {
// only add other files if we're NOT in preset-only mode (normal run, non-router)
} else {
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
-1
View File
@@ -55,7 +55,6 @@ struct common_download_opts {
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
bool download_mmproj = false;
bool download_mtp = false;
bool preset_only = false; // if true, only check & download remote preset (for router mode)
common_download_callback * callback = nullptr;
};
+324
View File
@@ -0,0 +1,324 @@
#include "json-partial.h"
#include "log.h"
#include <nlohmann/json.hpp>
#include <string>
#include <regex>
using json = nlohmann::ordered_json;
enum common_json_stack_element_type {
COMMON_JSON_STACK_ELEMENT_OBJECT,
COMMON_JSON_STACK_ELEMENT_KEY,
COMMON_JSON_STACK_ELEMENT_ARRAY,
};
struct common_json_stack_element {
common_json_stack_element_type type;
std::string key;
};
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out)
{
std::string::const_iterator it = input.begin();
const auto end = input.end();
return common_json_parse(it, end, healing_marker, out);
}
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out)
{
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
std::string last_token;
std::string exception_message;
std::vector<common_json_stack_element> stack;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
this->position = position - 1;
this->found_error = true;
this->last_token = last_token;
this->exception_message = ex.what();
return false;
}
void close_value() {
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
stack.pop_back();
}
}
bool null() override { // NOLINT
close_value();
return true;
}
bool boolean(bool) override { // NOLINT
close_value();
return true;
}
bool number_integer(number_integer_t) override { // NOLINT
close_value();
return true;
}
bool number_unsigned(number_unsigned_t) override { // NOLINT
close_value();
return true;
}
bool number_float(number_float_t, const string_t &) override { // NOLINT
close_value();
return true;
}
bool string(string_t &) override { // NOLINT
close_value();
return true;
}
bool binary(binary_t &) override { // NOLINT
close_value();
return true;
}
bool start_object(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
return true;
}
bool end_object() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
stack.pop_back();
close_value();
return true;
}
bool key(string_t & key) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
return true;
}
bool start_array(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
return true;
}
bool end_array() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
stack.pop_back();
close_value();
return true;
}
};
json_error_locator err_loc;
auto start = it;
json::sax_parse(it, end, &err_loc);
if (err_loc.found_error) {
it = start;
auto temptative_end = it + err_loc.position;
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
auto input = std::string(it, temptative_end);
try {
out.json = json::parse(input);
// out.json = json::parse(it, temptative_end);
it = temptative_end;
return true;
} catch (const std::exception & ex) {
// No, needs healing.
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
}
auto can_parse = [](const std::string & str) {
try {
auto _ = json::parse(str); // NOLINT
return true;
} catch (const std::exception &) {
return false;
}
};
if (!healing_marker.empty() && !err_loc.stack.empty()) {
std::string str(it, temptative_end);
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
if (last_non_sp_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
auto last_non_sp_char = str[last_non_sp_pos];
// Used to detect stops on a number, which may not be complete.
auto was_maybe_number = [&]() {
if (!str.empty() && std::isspace(str.back())) {
return false;
}
return std::isdigit(last_non_sp_char) ||
last_non_sp_char == '.' ||
last_non_sp_char == 'e' ||
last_non_sp_char == 'E' ||
last_non_sp_char == '-';
};
std::string closing;
for (size_t i = err_loc.stack.size(); i > 0; i--) {
auto & el = err_loc.stack[i - 1];
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
closing += "}";
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
closing += "]";
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
throw std::runtime_error("Unexpected stack element type");
}
}
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
auto is_high_surrogate = [&](const std::string & s) {
// Check if a partial of a high surrogate (U+D800-U+DBFF)
return s.length() >= 4 &&
s[0] == '\\' && s[1] == 'u' &&
std::tolower(s[2]) == 'd' &&
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
};
// Initialize the unicode marker to a low surrogate to handle the edge case
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
// backslash (\)
std::string unicode_marker_padding = "udc00";
std::smatch last_unicode_seq;
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
std::smatch second_last_seq;
std::string prelude = str.substr(0, last_unicode_seq.position());
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
if (is_high_surrogate(last_unicode_seq.str())) {
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
unicode_marker_padding += "\\udc00";
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
if (is_high_surrogate(second_last_seq.str())) {
// If this follows a high surrogate, pad it to be a low surrogate
if (last_unicode_seq.length() == 2) {
unicode_marker_padding = "dc00";
} else if (last_unicode_seq.length() == 3) {
unicode_marker_padding = "c00";
} else {
// The original unicode_marker_padding is already padded with 0s
}
}
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
// We're inside an object value
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
// Was about to create an object value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + ": 1" + closing)) {
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
// Was about to create an object
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an object value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an object value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
// Cutting back to opening : for object value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
// Was about to create an array value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an array value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an array value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
} else {
auto last_pos = str.find_last_of("[,");
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
}
// Cutting back to last [ or , for array value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\": 1" + closing)) {
// Was inside an object key string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
// Was inside an object key string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "Cutting back to last : for object key+value\n");
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
out.json = json::parse(str);
it = temptative_end;
return true;
}
// handle unclosed top-level primitive
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
std::string str(it, temptative_end);
const auto & magic_seed = out.healing_marker.marker = healing_marker;
if (can_parse(str + "\"")) {
// Was inside an string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
// Was inside an string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
} else {
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
return false;
}
out.json = json::parse(str);
it = temptative_end;
return true;
}
return false;
}
out.json = json::parse(it, end);
it = end;
return true;
}
+39
View File
@@ -0,0 +1,39 @@
#pragma once
// TODO: use json_fwd.hpp when possible
#include <nlohmann/json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
// Raw marker.
std::string marker;
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
std::string json_dump_marker;
};
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
nlohmann::ordered_json json;
common_healing_marker healing_marker;
};
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out);
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out);
-6
View File
@@ -46,7 +46,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DbrxForCausalLM": "dbrx",
"DeciLMForCausalLM": "deci",
"DeepseekForCausalLM": "deepseek",
"DeepseekOCRForCausalLM": "deepseek",
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
@@ -97,7 +96,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
"GraniteMoeHybridForCausalLM": "granite",
"GraniteMoeSharedForCausalLM": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"Grok1ForCausalLM": "grok",
"GrokForCausalLM": "grok",
"GroveMoeForCausalLM": "grovemoe",
@@ -125,7 +123,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
"LLaDAModelLM": "llada",
"LLaMAForCausalLM": "llama",
"Lfm25AudioTokenizer": "lfm2",
"Lfm2BidirectionalModel": "lfm2",
"Lfm2ForCausalLM": "lfm2",
"Lfm2Model": "lfm2",
"Lfm2MoeForCausalLM": "lfm2",
@@ -234,7 +231,6 @@ TEXT_MODEL_MAP: dict[str, str] = {
"UMT5ForConditionalGeneration": "t5",
"UMT5Model": "t5",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VLlama3ForCausalLM": "llama",
"VoxtralForConditionalGeneration": "llama",
"WavTokenizerDec": "wavtokenizer",
@@ -265,7 +261,6 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"GlmasrModel": "ultravox",
"Granite4VisionForConditionalGeneration": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"HunYuanVLForConditionalGeneration": "hunyuan",
"Idefics3ForConditionalGeneration": "smolvlm",
"InternVisionModel": "internvl",
@@ -301,7 +296,6 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"StepVLForConditionalGeneration": "step3",
"Step3p7ForConditionalGeneration": "step3",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VoxtralForConditionalGeneration": "ultravox",
"YoutuVLForConditionalGeneration": "youtuvl",
}
+2 -10
View File
@@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
from .qwen import QwenModel
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
@ModelBase.register("DeepseekOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -205,8 +205,6 @@ class DeepseekModel(TextModel):
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekOCRForCausalLM",
"UnlimitedOCRForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
@@ -226,7 +224,7 @@ class DeepseekV2Model(TextModel):
self.origin_hf_arch = hparams.get('architectures', [None])[0]
# special handling for Deepseek OCR
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture()
@@ -352,12 +350,6 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
if is_ocr:
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
if sliding_window:
self.gguf_writer.add_sliding_window(sliding_window)
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
-28
View File
@@ -348,34 +348,6 @@ class GraniteSpeechMmprojModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
has_vision_encoder = False
has_audio_encoder = True
def set_gguf_parameters(self):
assert self.hparams_audio is not None
super().set_gguf_parameters()
# Add feature_layer if present in encoder config
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
self.gguf_writer.add_audio_feature_layers(feature_layers)
logger.info(f"gguf: audio feature_layers = {feature_layers}")
# Validate projector dimension matches concatenated encoder output
hidden_dim = self.hparams_audio["hidden_dim"]
expected_dim = hidden_dim * (len(feature_layers) + 1)
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
if projector_dim != expected_dim:
raise ValueError(
f"Projector encoder_hidden_size ({projector_dim}) does not match "
f"expected concatenated dimension ({expected_dim}). "
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
)
@ModelBase.register("Granite4VisionForConditionalGeneration")
class Granite4VisionMmprojModel(MmprojModel):
has_vision_encoder = True
+3 -10
View File
@@ -64,17 +64,11 @@ class LFM2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
@ModelBase.register("Lfm2Model")
class LFM2ColBertModel(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
dense_tensor_name = "dense_2"
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hf_arch == "Lfm2BidirectionalModel":
self.gguf_writer.add_causal_attention(False)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self.dense_tensor_name):
name = "model." + name
@@ -82,11 +76,10 @@ class LFM2ColBertModel(LFM2Model):
yield from super().modify_tensors(data_torch, name, bid)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# optional dense tensor is stored in a separate safetensors file
# dense tensor is stored in a separate safetensors file
from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
if not tensors_file.is_file():
return
assert tensors_file.is_file()
tensor = load_file(tensors_file)["linear.weight"]
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
yield f"{self.dense_tensor_name}.weight", tensor.clone()
@@ -24,6 +24,7 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
@@ -46,6 +47,7 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
@@ -71,6 +73,7 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "OFF",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
+1 -3
View File
@@ -266,6 +266,7 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
"ggml: OpenCL API version to target")
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
@@ -340,9 +341,6 @@ set(GGML_PUBLIC_HEADERS
include/gguf.h)
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
#if (GGML_METAL)
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
#endif()
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
install(TARGETS ggml-base LIBRARY)
+23 -50
View File
@@ -3688,6 +3688,8 @@ static void ggml_compute_forward_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -3701,49 +3703,25 @@ static void ggml_compute_forward_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
const float * xf = (const float *) x;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, x);
float mean = sum/ne00;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, xf);
float mean = sum/ne00;
float * yf = (float *) y;
float variance = 0;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
float variance = 0;
#ifdef GGML_USE_ACCELERATE
mean = -mean;
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
vDSP_measqv(yf, 1, &variance, ne00);
mean = -mean;
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
vDSP_measqv(y, 1, &variance, ne00);
#else
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
#endif //GGML_USE_ACCELERATE
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, yf, scale);
} else {
float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += *(const float *) (x + i00*nb00);
}
const float mean = sum/ne00;
float variance = 0.0f;
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float v = *(const float *) (x + i00*nb00) - mean;
*(float *) (y + i00*nb0) = v;
variance += v * v;
}
variance /= ne00;
const float scale = 1.0f/sqrtf(variance + eps);
for (int64_t i00 = 0; i00 < ne00; i00++) {
*(float *) (y + i00*nb0) *= scale;
}
}
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
@@ -4164,6 +4142,8 @@ static void ggml_compute_forward_l2_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -4178,27 +4158,20 @@ static void ggml_compute_forward_l2_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float xi = *(const float *) (x + i00*nb00);
sum += (ggml_float)(xi * xi);
sum += (ggml_float)(x[i00] * x[i00]);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
memcpy(y, x, ne00 * sizeof(float));
ggml_vec_scale_f32(ne00, (float *) y, scale);
} else {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float xi = *(const float *) (x + i00*nb00);
*(float *) (y + i00*nb0) = xi * scale;
}
}
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
+1 -1
View File
@@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous_rows(op->src[0]);
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]);
break;
+4
View File
@@ -25,6 +25,7 @@ include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
@@ -71,12 +72,15 @@ function(build_htp_skel V)
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
-DDSP_VERSION=${V}
-DPREBUILT_LIB_DIR="toolv19_${V}")
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
endfunction()
build_htp_skel(v68)
build_htp_skel(v69)
build_htp_skel(v73)
build_htp_skel(v75)
build_htp_skel(v79)
File diff suppressed because it is too large Load Diff
+50 -156
View File
@@ -5,12 +5,10 @@
#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include <algorithm>
#include <string>
#include <vector>
#include <stdio.h>
#include "htp-ops.h"
#include "htp/matmul-ops.h"
struct htp_opnode {
ggml_tensor * node = nullptr;
@@ -19,13 +17,6 @@ struct htp_opnode {
htp_op_code opcode = HTP_OP_INVALID;
std::vector<ggml_tensor *> extra_dsts;
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0};
htp_opnode(ggml_tensor * node = nullptr, std::vector<ggml_tensor *> fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector<ggml_tensor *> extra_dsts = {})
: node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {}
ggml_op op() const {
return node->op;
}
@@ -34,26 +25,6 @@ struct htp_opnode {
return fused.empty() ? node : fused.back();
}
void add_fused(ggml_tensor * t, bool extra_dst = false) {
fused.push_back(t);
if (extra_dst) {
extra_dsts.push_back(t);
}
}
std::vector<const ggml_tensor *> get_outputs() const {
std::vector<const ggml_tensor *> res;
if (extra_dsts.empty()) {
res.push_back(dst());
} else {
res.push_back(node);
for (const auto * x : extra_dsts) {
res.push_back(x);
}
}
return res;
}
const ggml_tensor * src0() const {
return node->src[0];
}
@@ -66,6 +37,10 @@ struct htp_opnode {
return ggml_op_is_empty(node->op);
}
void add_fused(ggml_tensor * t) {
fused.push_back(t);
}
bool stackable() const {
switch (this->op()) {
case GGML_OP_MUL_MAT:
@@ -156,117 +131,87 @@ struct htp_opformat {
char types[16 * GGML_MAX_SRC];
char buffs[64 * GGML_MAX_SRC];
char names[64 * GGML_MAX_SRC];
char kparams[128];
int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) {
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
if (!t) {
return snprintf(str, max_size, "NONE");
return sprintf(str, "NONE");
}
if (t->ne[2] == 1 && t->ne[3] == 1) {
return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
} else {
return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
}
}
void format_op_dims(char * str, size_t max_size, const htp_opnode & node) {
void format_op_dims(char * str, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p));
p += format_tensor_dims(p, inputs[0]);
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p));
}
p += sprintf(p, " x ");
p += format_tensor_dims(p, inputs[i]);
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
}
char self[64];
format_tensor_dims(self, sizeof(self), node.dst());
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
}
format_tensor_dims(self, node.dst());
p += sprintf(p, "%s", self);
}
int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) {
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
if (!t) {
return snprintf(str, max_size, "NONE");
return sprintf(str, "NONE");
}
const char * c = ggml_is_contiguous(t) ? "" : "!";
if (t->ne[2] == 1 && t->ne[3] == 1) {
return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
} else {
return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
}
}
void format_op_strides(char * str, size_t max_size, const htp_opnode & node) {
void format_op_strides(char * str, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p));
p += format_tensor_strides(p, inputs[0]);
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p));
}
p += sprintf(p, " x ");
p += format_tensor_strides(p, inputs[i]);
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
}
char self[64];
format_tensor_strides(self, sizeof(self), node.dst());
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
}
format_tensor_strides(self, node.dst());
p += sprintf(p, "%s", self);
}
void format_op_types(char * str, size_t max_size, const htp_opnode & node) {
void format_op_types(char * str, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, " x ");
p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE");
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p));
}
p += sprintf(p, "%s", ggml_type_name(node.dst()->type));
}
const char * tensor_buff_name(const struct ggml_tensor * t) {
@@ -276,102 +221,51 @@ struct htp_opformat {
return "NONE";
}
void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) {
void format_op_buffs(char * str, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p));
}
p += sprintf(p, "%s", tensor_buff_name(inputs[0]));
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p));
}
p += sprintf(p, " x ");
p += sprintf(p, "%s", tensor_buff_name(inputs[i]));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p));
}
p += sprintf(p, "%s", tensor_buff_name(node.dst()));
}
void format_op_names(char * str, size_t max_size, const htp_opnode & node) {
void format_op_names(char * str, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, " x ");
p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE");
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p));
}
}
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) {
path = "hmx-tiled";
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM ||
type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) {
path = "hvx-tiled";
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR ||
type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR ||
type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) {
path = "hvx-flat";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else {
snprintf(str, max_size, "----");
}
p += sprintf(p, "%s", node.dst()->name);
}
void format(const htp_opnode & node) {
format_op_dims(dims, sizeof(dims), node);
format_op_strides(strides, sizeof(strides), node);
format_op_types(types, sizeof(types), node);
format_op_buffs(buffs, sizeof(buffs), node);
format_op_names(names, sizeof(names), node);
format_kernel_params(kparams, sizeof(kparams), node);
format_op_dims(dims, node);
format_op_strides(strides, node);
format_op_types(types, node);
format_op_buffs(buffs, node);
format_op_names(names, node);
}
htp_opformat() {
strides[0] = '\0';
dims[0] = '\0';
types[0] = '\0';
buffs[0] = '\0';
names[0] = '\0';
kparams[0] = '\0';
}
htp_opformat() {}
htp_opformat(const htp_opnode & node) { format(node); }
};
+38 -14
View File
@@ -19,9 +19,43 @@ add_library(${HTP_LIB} SHARED
htp_iface_skel.c
worker-pool.c
hex-dma.c
hmx-queue.c
flash-attn-ops.c
hmx-flash-attn-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
target_sources(${HTP_LIB} PRIVATE
matmul-ops.c
binary-ops.c
unary-ops.c
@@ -29,6 +63,7 @@ add_library(${HTP_LIB} SHARED
softmax-ops.c
act-ops.c
rope-ops.c
flash-attn-ops.c
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
@@ -44,17 +79,6 @@ add_library(${HTP_LIB} SHARED
pad-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
install(TARGETS ${HTP_LIB})
+15 -13
View File
@@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED)
endif()
set(HEXAGON_TOOLCHAIN_INCLUDED true)
# Cross Compiling for Hexagon
#Cross Compiling for Hexagon
set(HEXAGON TRUE)
set(CMAKE_SYSTEM_NAME QURT)
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
@@ -14,6 +14,7 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
set(CUSTOM_RUNELF_PATH "")
#To fix backward compatibility with EAI addon.
if (NOT HEXAGON_SDK_ROOT)
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
endif()
@@ -30,6 +31,7 @@ endif()
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
#Get the Binary extension of the Hexagon Toolchain
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
endif()
@@ -46,12 +48,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
HEXAGON_TOOLS_ROOT
)
# QURT Related includes and linker flags
#QURT Related includes and linker flags
set(V_ARCH ${HEXAGON_ARCH})
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
if (${TREE} MATCHES PAKMAN)
if( ${TREE} MATCHES PAKMAN )
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
endif()
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
@@ -81,9 +83,11 @@ set(QURT_START_LINK_LIBS
)
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o)
set(QURT_END_LINK_LIBS
${TARGET_DIR}/fini.o
)
# Non QURT related includes and linker flags
#Non QURT related includes and linker flags
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
@@ -95,10 +99,8 @@ if (NOT NO_WRAP_MEM_API)
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
endif()
set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx")
set(PIC_SHARED_LD_FLAGS
${ARCH_FLAGS}
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
-G0
-fpic
-Wl,-Bsymbolic
@@ -118,13 +120,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
# System include paths
#System include paths
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
# LLVM toolchain setup
# Compiler paths, options and architecture
#LLVM toolchain setup
#Compiler paths, options and architecture
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
@@ -135,8 +137,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
# Compiler Options
set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
#Compiler Options
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
+3 -2
View File
@@ -18,8 +18,7 @@
#include "htp-ctx.h"
#include "htp-ops.h"
#include "htp-ops.h"
int hmx_flash_attn_ext(struct htp_ops_context * octx);
#include "hmx-ops.h"
// Must be multiple of 32
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
@@ -634,6 +633,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
#ifdef HTP_HAS_HMX
// HMX path: head_dim multiple of 64, F16 KV, and no sinks
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
int ret = hmx_flash_attn_ext(octx);
@@ -642,6 +642,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
}
// VTCM too small or other failure -> fall through to HVX path
}
#endif
struct htp_fa_context factx;
factx.octx = octx;
-80
View File
@@ -1,80 +0,0 @@
#ifndef HEX_COMMON_H
#define HEX_COMMON_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#ifndef SIZE_MAX
#define SIZE_MAX ((size_t)-1)
#endif
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
static inline uint32_t hex_ceil_pow2(uint32_t x) {
if (x <= 1) { return 1; }
int p = 2;
x--;
while (x >>= 1) { p <<= 1; }
return p;
}
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_swap_ptr(void ** p1, void ** p2) {
void * t = *p1;
*p1 = *p2;
*p2 = t;
}
static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) {
if (a != 0 && b > SIZE_MAX / a) return true;
*out = a * b;
return false;
}
static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) {
if (a > SIZE_MAX - b) return true;
*out = a + b;
return false;
}
#endif // HEX_COMMON_H
+5 -1
View File
@@ -5,7 +5,6 @@
#include <hexagon_types.h>
#include <stdbool.h>
#include <stdint.h>
#include "hex-utils.h"
#include "hex-profile.h"
@@ -128,8 +127,13 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
return p;
}
#if __HVX_ARCH__ < 73
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 0;
#else
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 1;
#endif
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
+56 -1
View File
@@ -11,7 +11,14 @@
#include "hex-fastdiv.h"
#include "hex-dump.h"
#include "hex-common.h"
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
static inline uint64_t hex_get_cycles() {
uint64_t cycles = 0;
@@ -25,6 +32,54 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt;
}
static inline uint32_t hex_ceil_pow2(uint32_t x) {
if (x <= 1) { return 1; }
int p = 2;
x--;
while (x >>= 1) { p <<= 1; }
return p;
}
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_swap_ptr(void ** p1, void ** p2) {
void * t = *p1;
*p1 = *p2;
*p2 = t;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control);
+13 -13
View File
@@ -49,7 +49,7 @@
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
@@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
@@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) =
struct hmx_fa_context {
const struct htp_ops_context * octx;
bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
uint32_t n_threads;
// Op parameters
@@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
return;
}
__fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
@@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
// Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1
const uint32_t n_threads = pipeline ? n_threads_init : 1;
const uint32_t n_threads = use_pipeline ? n_threads_init : 1;
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget);
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
// ======== Build context ========
struct hmx_fa_context factx;
@@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.n_kv_blocks = n_kv_blocks;
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32);
factx.pipeline = pipeline;
factx.use_pipeline = use_pipeline;
factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1);
// Extract op parameters (mutable during softcap adjustment, then stored as const in factx)
@@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
if (pipeline) {
if (use_pipeline) {
factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
} else {
factx.vtcm_v_tiles[1] = NULL;
@@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
// ======== HMX lock strategy ========
// Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend.
// Fallback: main thread holds the lock (original behavior).
if (!factx.pipeline) {
if (!factx.use_pipeline) {
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
}
@@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t k_src_stride = size_k_row_padded / sizeof(__fp16);
const size_t v_src_stride = size_v_row_padded / sizeof(__fp16);
if (factx.pipeline) {
if (factx.use_pipeline) {
// ==================================================================
// Pipeline path: HVX phases ‖ HMX queue worker
// ==================================================================
@@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
// HMX: O_final = diag(1/l) @ O_prev
if (factx.pipeline) {
if (factx.use_pipeline) {
on_job.o_curr = o_tile_curr;
on_job.o_prev = o_tile_prev;
on_job.d_tiles = factx.vtcm_d_tiles;
@@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
} // end KV head loop
} // end batch loop
if (factx.pipeline) {
if (factx.use_pipeline) {
hmx_queue_suspend(ctx->hmx_queue);
} else {
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+6
View File
@@ -0,0 +1,6 @@
// HMX operations compiled as a single translation unit.
// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO.
#include "hmx-queue.c"
#include "hmx-matmul-ops.c"
#include "hmx-flash-attn-ops.c"
+88
View File
@@ -0,0 +1,88 @@
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_OPS_H
#define HMX_OPS_H
#include <stddef.h>
#include <stdint.h>
#include "htp-ops.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct {
float *dst;
const float *activation;
const __fp16 *permuted_weight;
int m;
int k;
int n;
int act_stride;
int weight_stride;
int dst_stride;
int ne02;
int ne03;
int ne12;
int ne13;
size_t src0_nb2;
size_t src0_nb3;
size_t src1_nb2;
size_t src1_nb3;
size_t dst_nb2;
size_t dst_nb3;
} hmx_matmul_f16_f32_batched_params_t;
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
// nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
int hmx_matmul_f16_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const __fp16 *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride);
// Batched F16 wrapper over hmx_mat_mul_f16_f32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
int hmx_matmul_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride,
int weight_type);
struct mmid_row_mapping;
int hmx_matmul_id_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int ne11,
size_t act_nb1, size_t act_nb2,
size_t dst_nb1, size_t dst_nb2,
int weight_stride,
int weight_type,
const struct mmid_row_mapping *matrix_rows,
int cur_a,
int mapping_stride);
// HMX flash attention
int hmx_flash_attn_ext(struct htp_ops_context * octx);
#ifdef __cplusplus
}
#endif
#endif // HMX_OPS_H
+3 -9
View File
@@ -13,9 +13,7 @@
#include <stdint.h>
#include <stdbool.h>
#ifndef HTP_MAX_NTHREADS
#define HTP_MAX_NTHREADS 10
#endif
#define HTP_MAX_MMAPS 16
// Memory mapping
@@ -44,13 +42,9 @@ struct htp_ops_context {
enum htp_op_code op; // FIXME: rename to opcode
int32_t op_params[HTP_OP_MAX_PARAMS];
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS];
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
union {
const struct htp_tensor * dst;
const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS];
};
const struct htp_tensor * dst;
// TODO convert these to an array
struct htp_spad src0_spad;
@@ -93,13 +87,13 @@ struct htp_context {
struct htp_ops_context octx;
#ifdef HTP_HAS_HMX
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
#endif
};
int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);
int op_matmul_qkv(struct htp_ops_context * octx);
int op_matmul_ffn(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);
+8 -15
View File
@@ -28,19 +28,18 @@ enum htp_data_type {
HTP_TYPE_MXFP4 = 39,
// types used internally for repack, dyn.quant, etc
HTP_TYPE_Q4_0_TILED = 200,
HTP_TYPE_Q4_1_TILED,
HTP_TYPE_Q8_0_TILED,
HTP_TYPE_MXFP4_TILED,
HTP_TYPE_Q4_0x4x2 = 200,
HTP_TYPE_Q4_1x4x2,
HTP_TYPE_Q8_0x4x2,
HTP_TYPE_MXFP4x4x2,
HTP_TYPE_INVALID
};
// Constats for internal types
#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout
#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout
#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
// Mask to enable various stages of the Ops.
@@ -58,8 +57,6 @@ enum htp_op_code {
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT,
HTP_OP_MUL_MAT_ID,
HTP_OP_MUL_MAT_QKV,
HTP_OP_MUL_MAT_FFN,
HTP_OP_RMS_NORM,
HTP_OP_RMS_NORM_MUL,
HTP_OP_UNARY_SILU,
@@ -102,9 +99,7 @@ enum htp_op_code {
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
#define HTP_OP_MAX_OUTPUTS 4
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
#define HTP_OP_MAX_KERN_PARAMS 32
#define HTP_OP_MAX_BUFS 16
#define HTP_OP_MAX_REQS 256
@@ -147,10 +142,8 @@ struct htp_op_desc {
uint32_t opcode; // GGML/HTP Op
uint32_t flags; // Op flags
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices
uint16_t pad[2]; // padding to align to 64 bits
uint16_t dst; // Output tensor index
};
#ifndef HTP_MAX_NTHREADS
+1 -2
View File
@@ -11,13 +11,12 @@ struct htp_iface_pmu_conf {
};
interface htp_iface : remote_handle64 {
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem);
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
AEEResult stop();
AEEResult mmap(in uint32 fd, in uint32 size);
AEEResult munmap(in uint32 fd);
AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
AEEResult etm(in uint32 enable);
AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size);
};
#endif /* HTP_IDL */
+18 -13
View File
@@ -170,7 +170,25 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
}
#endif
/* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
{
HVX_Vector const vzero = Q6_V_vzero();
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
return ret;
}
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
{
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
@@ -287,17 +305,4 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
#endif // __HVX_ARCH__ < 79
static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) {
if (kt % 4 == 0) {
*v_act_all = hvx_vmem(y_q + kt * 32);
return *v_act_all;
} else if (kt % 4 == 1) {
return Q6_V_vror_VR(*v_act_all, 32);
} else if (kt % 4 == 2) {
return Q6_V_vror_VR(*v_act_all, 64);
} else {
return Q6_V_vror_VR(*v_act_all, 96);
}
}
#endif /* HVX_BASE_H */
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+23 -81
View File
@@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context);
AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) {
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
struct htp_context * ctx = (struct htp_context *) handle;
if (!ctx) {
@@ -395,9 +395,10 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp
return AEE_ENOMEMORY;
}
ctx->hmx_enabled = n_hmx;
#ifdef HTP_HAS_HMX
ctx->hmx_enabled = use_hmx;
ctx->hmx_queue = NULL;
if (n_hmx) {
if (use_hmx) {
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
if (ctx->hmx_queue) {
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
@@ -406,7 +407,8 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp
ctx->hmx_enabled = false;
}
}
FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx);
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
#endif
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
@@ -479,11 +481,13 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
dma_queue_delete(ctx->dma[i]);
}
#ifdef HTP_HAS_HMX
if (ctx->hmx_queue) {
hmx_queue_delete(ctx->hmx_queue);
ctx->hmx_queue = NULL;
}
ctx->hmx_enabled = false;
#endif
vtcm_free(ctx);
@@ -496,36 +500,6 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
return AEE_SUCCESS;
}
AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) {
(void)handle;
if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) {
return AEE_EBADPARM;
}
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
uint32_t n_hvx_val = hw_nhvx;
if (n_hvx_val > hw_threads.max_hthreads) {
n_hvx_val = hw_threads.max_hthreads;
}
if (n_hvx_val > HTP_MAX_NTHREADS) {
n_hvx_val = HTP_MAX_NTHREADS;
}
// for now we force n_threads == n_hvx
*n_threads = n_hvx_val;
*n_hvx = n_hvx_val;
*n_hmx = 1;
uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback
HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL);
*vtcm_size = vtcm_sz;
return AEE_SUCCESS;
}
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
// No errors expected on the DSP.
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
@@ -580,12 +554,6 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_MUL_MAT_ID:
return op_matmul_id(octx);
case HTP_OP_MUL_MAT_QKV:
return op_matmul_qkv(octx);
case HTP_OP_MUL_MAT_FFN:
return op_matmul_ffn(octx);
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
@@ -794,9 +762,8 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str
}
}
static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
memcpy(octx->op_params, op->params, sizeof(octx->op_params));
memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params));
octx->flags = op->flags;
octx->op = op->opcode;
@@ -818,41 +785,22 @@ static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, u
src->ne[0], src->ne[1], src->ne[3], src->ne[3]);
}
// Prep output tensors
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
uint16_t dst_idx = op->dst[i];
if (dst_idx == 0xffff) {
octx->dsts[i] = NULL;
continue;
}
struct htp_tensor *dst = tens + dst_idx;
octx->dsts[i] = dst;
// Prep output tensor
struct htp_tensor *dst = tens + op->dst;
FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
octx->dst = dst;
int status = execute_op(octx);
FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
octx->src0_spad.src = NULL;
octx->src1_spad.src = NULL;
octx->src2_spad.src = NULL;
octx->src3_spad.src = NULL;
octx->dst_spad.src = NULL;
(void) execute_op(octx);
// flush buffers on output
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
if (octx->dsts[i]) {
struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i];
hex_l2flush((void *) dst->data, dst->size);
dst->flags |= HTP_TENSOR_FLUSHED;
hex_l2flush((void *) dst->data, dst->size);
dst->flags |= HTP_TENSOR_FLUSHED;
FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
}
return status;
FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
}
#define DSPQUEUE_POLL_TIMEOUT_USEC 100
@@ -944,26 +892,20 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
}
}
int op_status = HTP_STATUS_OK;
uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch
for (uint32_t i=0; i < n_ops; i++) {
struct profile_data prof;
if (i == op_wakeup) {
if (i == (n_ops-1)) {
// wake up the host before starting the last op
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
}
profile_start(ctx->profiler, &prof);
op_status = proc_op_req(octx, tens, i, &ops[i]);
proc_op_req(octx, tens, i, &ops[i]);
profile_stop(ctx->profiler, &prof);
if (op_status != HTP_STATUS_OK) {
break;
}
if (ctx->profiler) {
pds[i].opcode = ops[i].opcode;
pds[i].usecs = prof.usecs;
@@ -977,7 +919,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_opbatch_rsp rsp;
rsp.id = req.id;
rsp.status = op_status;
rsp.status = HTP_STATUS_OK;
rsp.n_bufs = n_bufs;
rsp.n_tensors = n_tens;
rsp.n_ops = n_ops;
File diff suppressed because it is too large Load Diff
-508
View File
@@ -1,508 +0,0 @@
#ifndef HTP_MATMUL_OPS_H
#define HTP_MATMUL_OPS_H
#include <stdint.h>
#include <stddef.h>
#include "htp-ops.h"
#include "hex-fastdiv.h"
#include "hex-common.h"
#ifdef __cplusplus
extern "C" {
#endif
// --- HMX Tile Constraints ---
#define HTP_MM_HMX_TILE_N_COLS 32
#define HTP_MM_HMX_TILE_N_ROWS 32
#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes
#define HTP_MM_HMX_TILE_N_ELMS 1024
#define HTP_MM_HMX_MIN_NROWS 4
// --- Weight Repacked Tile Sizes ---
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640
#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088
#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576
#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544
// --- Weight Repacked Aligned Tile Sizes ---
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640
// --- Activation Tiled Block Sizes (including padding) ---
#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152
#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280
#define HTP_MM_MAX_PREFETCH 16
// --- Solver Cost Model Penalty Weights (HMX-specific) ---
#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization
#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion
// --- DMA Activation Transfer Configuration ---
#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2
#define HTP_MM_DMA_ACT_MULTIPLIER 4
enum htp_mm_kernel_type {
HTP_MM_KERNEL_UNSUPPORTED = 0,
// HMX paths
HTP_MM_KERNEL_HMX_2D,
HTP_MM_KERNEL_HMX_F16_BATCHED,
// HVX floating-point paths
HTP_MM_KERNEL_HVX_F16_F16_VTCM,
HTP_MM_KERNEL_HVX_F16_F16_DDR,
HTP_MM_KERNEL_HVX_F16_F32_DDR,
HTP_MM_KERNEL_HVX_F32_F32_VTCM,
HTP_MM_KERNEL_HVX_F32_F32_DDR,
HTP_MM_KERNEL_HVX_F32_F16_DDR,
// HVX quantized paths
HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization
HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization
HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization
};
// Op-specific struct for precomputed matmul params
struct htp_mm_kernel_params {
int32_t kernel_type; // enum htp_mm_kernel_type
int32_t pipeline; // 1 = pipelined execution, 0 = standard
int32_t m_chunk; // Row chunk size (M chunk)
int32_t n_chunk; // Col chunk size (N chunk)
int32_t n_threads; // Number of threads to spawn
int32_t n_act_threads; // Number of threads for activation preparation
int32_t n_hmx; // 1 = use HMX, 0 = use HVX
int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM
int32_t tile_size; // Weight tile size
int32_t aligned_tile_size; // Aligned weight tile size (padded to 128)
int32_t src1_row_size; // Row size for quantized activation
int32_t vtcm_size; // Total required scratchpad size in VTCM
int32_t vtcm_src0_size; // src0 scratchpad size in VTCM
int32_t vtcm_src1_size; // src1 scratchpad size in VTCM
int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only)
int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only)
int32_t vtcm_dst_size; // dst scratchpad size in VTCM
// Precomputed division values
struct fastdiv_values div_ne12_ne1;
struct fastdiv_values div_ne1;
struct fastdiv_values div_r2;
struct fastdiv_values div_r3;
struct fastdiv_values div_ne11;
};
#if defined(__cplusplus)
static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
#else
_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
#endif
struct mmid_row_mapping {
uint32_t i1;
uint32_t i2;
};
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total,
size_t overhead,
size_t per_n_cost,
size_t per_m_cost,
size_t per_mn_cost,
size_t m,
size_t n,
size_t m_block_cost,
size_t n_block_cost,
size_t * m_chunk_out,
size_t * n_chunk_out,
size_t * total_out) {
if (m == 0 || n == 0) return -1;
if (vtcm_total <= overhead) return -1;
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
const size_t usable = vtcm_total - overhead;
size_t best_cost = SIZE_MAX;
size_t best_mn = 0;
size_t best_m = 0, best_n = 0;
const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS);
for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) {
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
if (n_fixed >= usable) goto next_nc;
if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;
{
size_t remain = usable - n_fixed;
size_t mc = remain / mc_denom;
mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS);
mc = hex_smin(mc, m);
if (mc == 0) {
goto next_nc;
}
size_t mblocks = ((size_t) m + mc - 1) / mc;
size_t nblocks = ((size_t) n + nc - 1) / nc;
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
size_t mn = mc * nc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_m = mc;
best_n = nc;
}
}
next_nc:
if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow
}
if (best_m == 0 || best_n == 0) return -1;
// Compute exact total (with overflow checks)
size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1;
if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1;
if (hex_mul_overflow(best_m, best_n, &mn)) return -1;
if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1;
if (hex_add_overflow(t0, t1, &total)) return -1;
if (hex_add_overflow(total, t2, &total)) return -1;
if (hex_add_overflow(total, overhead, &total)) return -1;
*m_chunk_out = best_m;
*n_chunk_out = best_n;
*total_out = total;
return 0;
}
// --- Tile Size Helpers ---
static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) {
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return HTP_MM_WEIGHT_TILE_SIZE_Q4_0;
case HTP_TYPE_Q4_1:
return HTP_MM_WEIGHT_TILE_SIZE_Q4_1;
case HTP_TYPE_Q8_0:
return HTP_MM_WEIGHT_TILE_SIZE_Q8_0;
case HTP_TYPE_MXFP4:
return HTP_MM_WEIGHT_TILE_SIZE_MXFP4;
default:
return 0;
}
}
static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) {
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0;
case HTP_TYPE_Q4_1:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1;
case HTP_TYPE_Q8_0:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0;
case HTP_TYPE_MXFP4:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4;
default:
return 0;
}
}
// --- Activation/Row Size Helpers ---
static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) {
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
const uint32_t nb_32 = ne_padded / 32;
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0;
}
static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) {
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
const uint32_t nb_32 = ne_padded / 32;
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1;
}
static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) {
const uint32_t quants_size = hex_align_up(ne, 128);
const uint32_t num_scales = (ne + 31) / 32;
const uint32_t scales_size = hex_align_up(num_scales * 2, 128);
return quants_size + scales_size;
}
static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) {
const uint32_t quants_size = hex_align_up(ne, 128);
const uint32_t num_scales = (ne + 31) / 32;
const uint32_t scales_size = hex_align_up(num_scales * 4, 128);
return quants_size + scales_size;
}
static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) {
uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED;
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
case HTP_TYPE_Q4_1:
case HTP_TYPE_Q8_0:
case HTP_TYPE_MXFP4:
return (size_t) nb * htp_mm_get_weight_tile_size(weight_type);
case HTP_TYPE_F16:
return (size_t) k * sizeof(__fp16);
case HTP_TYPE_F32:
return (size_t) k * sizeof(float);
default:
return 0;
}
}
static inline size_t htp_mm_round_up(size_t n, size_t m) {
return ((n + m - 1) / m) * m;
}
static inline bool htp_mm_hmx_pipeline(uint32_t m) {
return m > 32;
}
static inline void htp_mm_hmx_get_2d_chunk_costs(
int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size,
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
) {
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
const size_t vec_dot_size = k * sizeof(uint16_t);
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0;
*size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) +
(pipeline ? 2 * vec_dot_size : vec_dot_size);
*size_per_m_out = vec_dot_size;
*size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t);
}
static inline void htp_mm_hmx_get_batched_chunk_costs(
uint32_t k, uint32_t group_size,
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
) {
const size_t vec_dot_size = k * sizeof(uint16_t);
*size_per_n_out = 3 * vec_dot_size;
*size_per_m_out = group_size * vec_dot_size;
*size_per_mn_out = sizeof(uint16_t);
}
static inline size_t htp_mm_hmx_get_2d_vtcm_size(
int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size
) {
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
const size_t vec_dot_size = k * sizeof(uint16_t);
const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE);
size_t weight_area_size = is_quant
? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE)
: htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE);
if (pipeline) {
weight_area_size *= 2;
}
const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
size_t scratch1_size = pipeline ? scratch0_size : 0;
size_t scratch2_size = pipeline ? output_area_size : 0;
return weight_area_size + act_area_size + act_f32_size + output_area_size +
scratch0_size + scratch1_size + scratch2_size + 256;
}
static inline size_t htp_mm_hmx_get_batched_vtcm_size(
int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) {
(void)wtype;
(void)pipeline;
const size_t vec_dot_size = k * sizeof(uint16_t);
const size_t f32_scratch_size = use_dma_activation
? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0;
const size_t act_head_stride = mc * k;
const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
return weight_area_size + act_area_size + output_area_size +
2 * scratch_area_size + 256 + f32_scratch_size;
}
static inline size_t htp_mm_hvx_get_vtcm_sizes(
int kernel_type,
int wtype,
uint32_t ne10, // k
uint32_t src1_nrows, // m_total (or act_nrows)
uint32_t n_threads,
size_t dst_row_size,
size_t src0_row_size,
size_t src1_row_size,
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out,
size_t * vtcm_dst_size_out
) {
size_t vtcm_src0_size = 0;
size_t vtcm_src1_size = 0;
size_t vtcm_dst_size = 0;
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
wtype == HTP_TYPE_MXFP4);
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1;
switch (kernel_type) {
case HTP_MM_KERNEL_HVX_F16_F16_VTCM: {
size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128);
vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256);
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_F16_F32_DDR:
case HTP_MM_KERNEL_HVX_F16_F16_DDR:
case HTP_MM_KERNEL_HVX_F32_F32_DDR:
case HTP_MM_KERNEL_HVX_F32_F16_DDR: {
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads;
vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_F32_F32_VTCM: {
size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128);
vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256);
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_QUANT_BLOCK:
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
break;
}
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
break;
}
default:
break;
}
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = vtcm_src1_size;
*vtcm_dst_size_out = vtcm_dst_size;
return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size;
}
static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
int wtype,
uint32_t ne10, // k
uint32_t src1_nrows,
uint32_t n_threads,
size_t src0_row_size, // nb01
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out
) {
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
wtype == HTP_TYPE_MXFP4);
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10)
: htp_mm_q8_0_tiled_row_size(ne10);
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (src0_sz_per_thread < src1_row_size_padded) {
src0_sz_per_thread = src1_row_size_padded;
}
if (is_repack) {
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
const uint32_t n_k_tiles = ne10 / 32;
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
}
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = src1_sz;
return vtcm_src0_size + src1_sz;
}
#ifdef __cplusplus
}
#endif
#endif // HTP_MATMUL_OPS_H
+4
View File
@@ -14,6 +14,8 @@ Drivers_Dir = 13
1 = %DiskId%
[SourceDisksFiles]
libggml-htp-v68.so = 1
libggml-htp-v69.so = 1
libggml-htp-v73.so = 1
libggml-htp-v75.so = 1
libggml-htp-v79.so = 1
@@ -26,6 +28,8 @@ ExcludeFromSelect = *
CopyFiles=Drivers_Dir
[Drivers_Dir]
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+119 -51
View File
@@ -24,62 +24,119 @@ if (GGML_METAL_NDEBUG)
endif()
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
set(METALLIB_KERNELS_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/kernels/common.h")
set(METALLIB_KERNELS_DEQUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/dequantize.h")
set(METALLIB_KERNELS_QUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantize.h")
set(METALLIB_KERNEL_SOURCES
kernels/fa.metal
kernels/mul_mv.metal
kernels/mul_mm.metal
kernels/quantize.metal
kernels/softmax.metal
kernels/norm.metal
kernels/unary.metal
kernels/binbcast.metal
kernels/reduce.metal
kernels/tri.metal
kernels/ssm.metal
kernels/wkv.metal
kernels/gated_delta_net.metal
kernels/solve_tri.metal
kernels/rope.metal
kernels/conv.metal
kernels/upscale.metal
kernels/argsort.metal
kernels/pool.metal
kernels/misc.metal
)
if (GGML_METAL_EMBED_LIBRARY)
enable_language(ASM)
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
# merge ggml-common.h and ggml-metal.metal into a single file
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
set(METALLIB_EMBED_ASM_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(kind ${src} NAME_WE)
# symbol names must be valid C identifiers ('-' is not allowed)
string(REPLACE "-" "_" kind_sym ${kind})
add_custom_command(
OUTPUT "${METALLIB_EMBED_ASM}"
COMMAND echo "Embedding Metal library"
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
COMMENT "Generate assembly for embedded Metal library"
VERBATIM
)
set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${kind}.metal")
set(EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.metal")
set(ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.s")
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
# only prepend headers that this source actually includes
set(HEADERS_FOR_SRC ${METALLIB_KERNELS_COMMON})
file(STRINGS ${SRC} _has_dequantize REGEX "#include \"dequantize\\.h\"")
file(STRINGS ${SRC} _has_quantize REGEX "#include \"quantize\\.h\"")
if(_has_dequantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_DEQUANTIZE})
endif()
if(_has_quantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_QUANTIZE})
endif()
add_custom_command(
OUTPUT "${ASM}"
# Step 1: concatenate shared headers + this kernel source
COMMAND cat ${HEADERS_FOR_SRC} ${SRC} > "${EMBED}.tmp1"
# Step 2: remove internal #include and #pragma once
COMMAND sed -e "/\#include \"common.h\"/d" -e "/\#include \"dequantize.h\"/d" -e "/\#include \"quantize.h\"/d" -e "/\#pragma once/d" < "${EMBED}.tmp1" > "${EMBED}.tmp2"
# Step 3: inline ggml-common.h (replacing __embed_ggml-common.h__ sentinel)
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${EMBED}.tmp2" > "${EMBED}.tmp3"
# Step 4: inline ggml-metal-impl.h
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${EMBED}.tmp3" > "${EMBED}"
# Step 5: emit an asm chunk with kind-specific start/end symbols
# note: '-' is illegal in C symbols, so we use kind_sym; the macOS
# section name is limited to 16 chars so we keep it shared
# across kinds (__ggml_metallib) and only vary the global symbols.
COMMAND echo ".section __DATA,__ggml_metallib" > "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_start" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_start:" >> "${ASM}"
COMMAND echo .incbin "\"${EMBED}\"" >> "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_end" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_end:" >> "${ASM}"
DEPENDS ../ggml-common.h ggml-metal-impl.h
kernels/common.h kernels/dequantize.h kernels/quantize.h
kernels/${kind}.metal
COMMENT "Generate embedded Metal library for ${kind}"
VERBATIM
)
list(APPEND METALLIB_EMBED_ASM_FILES "${ASM}")
endforeach()
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM_FILES})
else()
# copy metal files to bin directory
# copy header files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
file(MAKE_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels")
configure_file(kernels/common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/common.h COPYONLY)
configure_file(kernels/dequantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/dequantize.h COPYONLY)
configure_file(kernels/quantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/quantize.h COPYONLY)
foreach(src ${METALLIB_KERNEL_SOURCES})
configure_file(${src} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} COPYONLY)
endforeach()
if (GGML_METAL_SHADER_DEBUG)
# custom command to do the following:
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
# xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
#
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
# disabling fast math is needed in order to pass tests/test-backend-ops
# note: disabling fast math is needed in order to pass tests/test-backend-ops
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
# ref: https://github.com/ggml-org/whisper.cpp/issues/1720
# note: adding -g causes segmentation fault during compile
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
set(XC_FLAGS -fno-fast-math -fno-inline)
else()
set(XC_FLAGS -O3)
endif()
# Append macOS metal versioning flags
if (GGML_METAL_MACOSX_VERSION_MIN)
message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
@@ -90,35 +147,46 @@ else()
list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
endif()
# Compile each kernel source to .air, then link into default.metallib
set(AIR_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(name ${src} NAME_WE)
set(AIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${name}.air")
list(APPEND AIR_FILES ${AIR})
add_custom_command(
OUTPUT ${AIR}
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -I ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} -o ${AIR}
DEPENDS ${src} kernels/common.h kernels/dequantize.h kernels/quantize.h ${METALLIB_COMMON} ggml-metal-impl.h
COMMENT "Compiling ${src}"
VERBATIM
)
endforeach()
add_custom_command(
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metallib ${AIR_FILES} -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
COMMENT "Compiling Metal kernels"
)
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h
COMMAND rm -rf ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels
DEPENDS ${AIR_FILES}
COMMENT "Linking Metal kernels into default.metallib"
)
# FIXME: only add to the ggml-metal target?
add_custom_target(
ggml-metal-lib ALL
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
)
)
endif() # GGML_METAL_EMBED_LIBRARY
if (NOT GGML_METAL_EMBED_LIBRARY)
install(
FILES src/ggml-metal/ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE
GROUP_READ
WORLD_READ
DESTINATION ${CMAKE_INSTALL_BINDIR})
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/
DESTINATION ${CMAKE_INSTALL_BINDIR}/kernels
FILES_MATCHING PATTERN "*.metal" PATTERN "*.h"
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
endif()
+423 -128
View File
@@ -94,8 +94,63 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
}
//
// MTLLibrary collection (one library per op-source, compiled separately)
//
// Single source of truth for the per-kind metal libraries. The order here
// defines the enum values and every per-kind table below, so adding a library
// is a one-line change here (plus adding its source to CMakeLists.txt).
// X(suffix, name): name is both the kernels/<name>.metal basename and the
// ggml_metallib_<name>_{start,end} embed-symbol stem.
#define GGML_METAL_LIBS \
X(FA, fa) \
X(MUL_MV, mul_mv) \
X(MUL_MM, mul_mm) \
X(QUANTIZE, quantize) \
X(SOFTMAX, softmax) \
X(NORM, norm) \
X(UNARY, unary) \
X(BINBCAST, binbcast) \
X(REDUCE, reduce) \
X(TRI, tri) \
X(SSM, ssm) \
X(WKV, wkv) \
X(GATED_DELTA_NET, gated_delta_net)\
X(SOLVE_TRI, solve_tri) \
X(ROPE, rope) \
X(CONV, conv) \
X(UPSCALE, upscale) \
X(ARGSORT, argsort) \
X(POOL, pool) \
X(MISC, misc)
enum ggml_metal_lib_kind {
#define X(e, s) GGML_METAL_LIB_##e,
GGML_METAL_LIBS
#undef X
GGML_METAL_LIB_COUNT,
};
static const char * const k_lib_names[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = #s,
GGML_METAL_LIBS
#undef X
};
struct ggml_metal_library {
id<MTLLibrary> obj;
// Per-kind compiled libraries. When single_library is true, the whole library
// (e.g. a pre-compiled default.metallib or a from-source build) lives at
// objs[0] and the remaining slots are nil.
id<MTLLibrary> objs[GGML_METAL_LIB_COUNT];
bool single_library; // true: combined library at objs[0]; false: per-kind libs in objs[*]
// Routing table: kernel function name -> objs[] index, populated from each
// compiled library's -[MTLLibrary functionNames]. The actual compiled
// libraries are the single source of truth for which library owns a kernel,
// so adding kernels later requires no manual routing maintenance.
// nil in single_library mode (everything resolves to objs[0]).
NSMutableDictionary<NSString *, NSNumber *> * fn_to_lib;
ggml_metal_device_t dev;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
@@ -103,160 +158,376 @@ struct ggml_metal_library {
NSLock * lock;
};
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLLibrary> library = nil;
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
// Build the fn_to_lib routing table by querying each compiled library's public
// function names. Call once after all per-kind libraries have been compiled.
static void ggml_metal_library_build_index(ggml_metal_library_t lib) {
@autoreleasepool {
NSMutableDictionary<NSString *, NSNumber *> * index = [[NSMutableDictionary alloc] init];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
for (NSString * fname in [lib->objs[kind] functionNames]) {
index[fname] = @(kind);
}
}
lib->fn_to_lib = index;
}
}
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
//
// TODO: move to a function
{
const int64_t t_start = ggml_time_us();
// Parse a `#include "name"` line. Returns the quoted name in *include_name on
// success. Whitespace-tolerant; ignores `#include <...>` (system headers).
static bool ggml_metal_library_parse_quoted_include(NSString * line, NSString ** include_name) {
NSScanner * scanner = [NSScanner scannerWithString:line];
scanner.charactersToBeSkipped = [NSCharacterSet whitespaceCharacterSet];
NSError * error = nil;
NSString * src = nil;
if (![scanner scanString:@"#" intoString:NULL] ||
![scanner scanString:@"include" intoString:NULL] ||
![scanner scanString:@"\"" intoString:NULL]) {
return false;
}
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
NSString * name = nil;
if (![scanner scanUpToString:@"\"" intoString:&name]) {
return false;
}
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
if (include_name) {
*include_name = name;
}
return true;
}
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
// Recursively inline `#include "name"` directives. System includes (<...>),
// `#if/#else/#endif`, and other preprocessor lines are passed through to the
// Metal compiler unchanged. `#pragma once` is dropped since `seen` already
// guards against double-inclusion.
static bool ggml_metal_library_flatten_file(NSMutableString * dst, NSString * path,
NSArray<NSString *> * search_paths,
NSMutableSet<NSString *> * seen, NSError ** error) {
NSString * key = [path stringByStandardizingPath];
if ([seen containsObject:key]) {
return true;
}
[seen addObject:key];
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:error];
if (!src) {
return false;
}
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSFileManager * fm = [NSFileManager defaultManager];
for (NSString * line in [src componentsSeparatedByString:@"\n"]) {
NSString * trimmed = [line stringByTrimmingCharactersInSet:[NSCharacterSet whitespaceCharacterSet]];
if ([trimmed isEqualToString:@"#pragma once"]) {
continue;
}
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
NSString * include_name = nil;
if (ggml_metal_library_parse_quoted_include(line, &include_name)) {
NSString * resolved = nil;
for (NSString * dir in search_paths) {
NSString * candidate = [dir stringByAppendingPathComponent:include_name];
if ([fm isReadableFileAtPath:candidate]) {
resolved = candidate;
break;
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
if (!resolved) {
if (error) {
NSString * msg = [NSString stringWithFormat:@"could not resolve include \"%@\" from '%@'", include_name, path];
*error = [NSError errorWithDomain:@"ggml-metal-source-flatten" code:1
userInfo:@{NSLocalizedDescriptionKey: msg}];
}
return false;
}
if (!ggml_metal_library_flatten_file(dst, resolved, search_paths, seen, error)) {
return false;
}
continue;
}
if (path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
[dst appendString:line];
[dst appendString:@"\n"];
}
library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
} else {
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
return true;
}
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
static NSString * ggml_metal_library_flatten_source(NSString * path_source, NSError ** error) {
// Search paths cover both runtime layout (build/bin/kernels + build/bin)
// and source-tree layout (ggml/src/ggml-metal/kernels + ggml/src/ggml-metal + ggml/src).
NSString * path_kernels = [path_source stringByDeletingLastPathComponent];
NSString * path_base = [path_kernels stringByDeletingLastPathComponent];
NSArray<NSString *> * search_paths = @[
path_kernels,
path_base,
[path_base stringByDeletingLastPathComponent],
];
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
NSMutableString * src = [[NSMutableString alloc] init];
NSMutableSet<NSString *> * seen = [NSMutableSet set];
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
if (!ggml_metal_library_flatten_file(src, path_source, search_paths, seen, error)) {
[src release];
return nil;
}
return src;
}
// Compile all per-kind libraries in parallel. `source_for_kind` returns the MSL
// source for a kind (the helper takes ownership and releases it), or nil with
// *err set on failure. On success the objs[] slots are populated and the routing
// index is built; on any failure every error is logged and false is returned
// (the caller is responsible for freeing `res`).
static bool ggml_metal_library_compile_all(
ggml_metal_library_t res,
id<MTLDevice> device,
NSDictionary * prep,
NSString * (^source_for_kind)(int kind, NSError ** err),
const char * origin) {
const int64_t t_start = ggml_time_us();
int64_t * t_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(int64_t));
NSError ** err_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(NSError *));
__block atomic_bool any_failure = false;
dispatch_group_t group = dispatch_group_create();
dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0);
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
dispatch_group_async(group, queue, ^{
const int64_t t0 = ggml_time_us();
NSError * error = nil;
NSString * src = source_for_kind(kind, &error);
if (!src) {
err_per_lib[kind] = [error retain];
atomic_store(&any_failure, true);
return;
}
if (path_source == nil) {
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
path_source = @"ggml-metal.metal";
}
id<MTLLibrary> lib = nil;
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
}
#endif
if (!library) {
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
//[options setFastMathEnabled:false];
lib = [device newLibraryWithSource:src options:options error:&error];
library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
#if !__has_feature(objc_arc)
[options release];
#endif
// retain the error before the autorelease pool drains it
if (!lib) {
err_per_lib[kind] = [error retain];
}
}
[src release];
t_per_lib[kind] = ggml_time_us() - t0;
if (!lib) {
atomic_store(&any_failure, true);
return;
}
res->objs[kind] = lib;
});
}
dispatch_group_wait(group, DISPATCH_TIME_FOREVER);
dispatch_release(group);
const bool ok = !atomic_load(&any_failure);
if (ok) {
const int64_t t_total = ggml_time_us() - t_start;
int64_t t_max = 0;
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
GGML_LOG_DEBUG("%s: compiled '%s' library in %.3f sec\n",
__func__, k_lib_names[kind], t_per_lib[kind] / 1e6);
if (t_per_lib[kind] > t_max) t_max = t_per_lib[kind];
}
GGML_LOG_INFO("%s: loaded %d libraries from %s in %.3f sec (max single = %.3f sec)\n",
__func__, GGML_METAL_LIB_COUNT, origin, t_total / 1e6, t_max / 1e6);
ggml_metal_library_build_index(res);
} else {
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (err_per_lib[kind]) {
GGML_LOG_ERROR("%s: failed to build '%s' library: %s\n", __func__,
k_lib_names[kind], [[err_per_lib[kind] description] UTF8String]);
[err_per_lib[kind] release];
}
}
#if GGML_METAL_EMBED_LIBRARY
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
}
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
free(err_per_lib);
free(t_per_lib);
res->obj = library;
return ok;
}
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
// shared MTLCompileOptions preprocessor macros (matches the build-time defines)
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
// start/end symbols emitted by CMake (see CMakeLists.txt), one pair per kind
#define X(e, s) extern const char ggml_metallib_##s##_start[]; extern const char ggml_metallib_##s##_end[];
GGML_METAL_LIBS
#undef X
static const char * const lib_start[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_start,
GGML_METAL_LIBS
#undef X
};
static const char * const lib_end[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_end,
GGML_METAL_LIBS
#undef X
};
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
(void) err;
return [[NSString alloc] initWithBytes:lib_start[kind]
length:(lib_end[kind] - lib_start[kind])
encoding:NSUTF8StringEncoding];
}, "embedded data");
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#else
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
const int64_t t_start = ggml_time_us();
NSError * error = nil;
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
}
if (path_lib != nil) {
// pre-compiled library found: a single combined default.metallib
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
res->objs[0] = [device newLibraryWithURL:libURL error:&error];
res->single_library = true;
if (!res->objs[0]) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
ggml_metal_library_free(res);
return NULL;
}
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
return res;
}
// no pre-compiled metallib: fall back to compiling each kernel source separately
GGML_LOG_INFO("%s: default.metallib not found, loading kernel sources\n", __func__);
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
if (path_resource) {
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, [path_resource UTF8String]);
}
// resolve each kind's source path up front (file lookup/logging stays on the calling thread)
NSString ** path_per_kind = calloc(GGML_METAL_LIB_COUNT, sizeof(NSString *));
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
NSString * rel = [NSString stringWithFormat:@"kernels/%s.metal", k_lib_names[kind]];
NSString * path_source = nil;
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:rel];
} else {
NSString * stem = [NSString stringWithFormat:@"kernels/%s", k_lib_names[kind]];
path_source = [bundle pathForResource:stem ofType:@"metal"];
}
if (path_source == nil || ![[NSFileManager defaultManager] isReadableFileAtPath:path_source]) {
GGML_LOG_WARN("%s: could not locate %s in bundle, falling back to cwd\n", __func__, [rel UTF8String]);
path_source = rel;
}
GGML_LOG_DEBUG("%s: loading '%s'\n", __func__, [path_source UTF8String]);
path_per_kind[kind] = [path_source retain];
}
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
return ggml_metal_library_flatten_source(path_per_kind[kind], err);
}, "source");
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
[path_per_kind[kind] release];
}
free(path_per_kind);
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#endif
}
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
@@ -318,10 +589,11 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
return NULL;
}
res->obj = library;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
res->objs[0] = library;
res->single_library = true;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
return res;
}
@@ -331,8 +603,14 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
return;
}
if (lib->obj) {
[lib->obj release];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (lib->objs[kind]) {
[lib->objs[kind] release];
}
}
if (lib->fn_to_lib) {
[lib->fn_to_lib release];
}
ggml_metal_pipelines_free(lib->pipelines);
@@ -393,11 +671,28 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
// route to the library that actually defines this kernel; fn_to_lib is
// built from -[MTLLibrary functionNames] so it's always in sync
int lib_idx = 0;
if (!lib->single_library) {
NSNumber * idx = lib->fn_to_lib[base_func];
if (!idx) {
[lib->lock unlock];
GGML_LOG_ERROR("%s: kernel not found in any metal library: base = '%s', name = '%s'\n", __func__, base, name);
return res;
}
lib_idx = [idx intValue];
}
id<MTLLibrary> mtl_lib = lib->objs[lib_idx];
id<MTLFunction> mtl_function;
if (!cv) {
mtl_function = [lib->obj newFunctionWithName:base_func];
mtl_function = [mtl_lib newFunctionWithName:base_func];
} else {
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
mtl_function = [mtl_lib newFunctionWithName:base_func constantValues:cv->obj error:&error];
}
if (!mtl_function) {
[lib->lock unlock];
File diff suppressed because it is too large Load Diff
+232
View File
@@ -0,0 +1,232 @@
#include "common.h"
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
shmem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int k = 2; k <= ntg.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (shmem_i32[col] >= args.ne00 ||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
} else {
if (shmem_i32[ixj] >= args.ne00 ||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding
if (i0 + col < args.ne0 && col < args.top_k) {
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
}
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
typedef void (argsort_merge_t)(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
+ i01*args.ne0
+ i02*args.ne0*args.ne01
+ i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
+ i01*args.top_k
+ i02*args.top_k*args.ne01
+ i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
+ args.nb02*i02
+ args.nb03*i03);
if (total == 0) {
return;
}
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) {
return;
}
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);
// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}
int i = low;
int j = k0 - i;
// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
out_idx = idx0;
++i;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}
dst[k] = out_idx;
}
}
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
+226
View File
@@ -0,0 +1,226 @@
#include "common.h"
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb
if (FC_RB) {
// row broadcast
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
if (FC_F == 1) {
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
if (FC_OP == 0) {
dst_row[i0] = src0_row[i0] + src1_row[i1];
}
if (FC_OP == 1) {
dst_row[i0] = src0_row[i0] - src1_row[i1];
}
if (FC_OP == 2) {
dst_row[i0] = src0_row[i0] * src1_row[i1];
}
if (FC_OP == 3) {
dst_row[i0] = src0_row[i0] / src1_row[i1];
}
} else {
T0 res = src0_row[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
dst_row[i0] = res;
}
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (i01 >= args.ne01) {
return;
}
const int i13 = i03%args.ne13;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
if (FC_F == 1) {
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}
if (FC_OP == 1) {
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
}
if (FC_OP == 2) {
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
}
if (FC_OP == 3) {
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
}
}
} else {
device const T1 * src1_ptr[8];
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
T res = src0_ptr[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += src1_ptr[j][i10];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= src1_ptr[j][i10];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= src1_ptr[j][i10];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= src1_ptr[j][i10];
}
}
dst_ptr[i0] = res;
}
}
}
#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
kernel void kernel_add_id(
constant ggml_metal_kargs_add_id & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i1 = tgpig.x;
const int i2 = tgpig.y;
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
const size_t nb1 = args.ne0 * sizeof(float);
const size_t nb2 = args.ne1 * nb1;
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
const int i03 = i3%args.ne03;
const int i02 = i2%args.ne02;
const int i01 = i1%args.ne01;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i00 = i0%args.ne00;
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
}
}
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_repeat_bf16")]] kernel kernel_repeat_t kernel_repeat<bfloat>;
#endif
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+126
View File
@@ -0,0 +1,126 @@
#pragma once
#include "ggml-metal-impl.h"
#include <metal_stdlib>
#ifdef GGML_METAL_HAS_TENSOR
#include <metal_tensor>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
#endif
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// cmd:
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/kernels/<src>.metal
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/kernels/<src>.metal
//
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
#undef GGML_METAL_HAS_BF16
#endif
#if defined(GGML_METAL_HAS_BF16)
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static inline float e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
return as_type<float>(bits);
}
static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
GGML_SORT_ORDER_DESC,
};
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
+485
View File
@@ -0,0 +1,485 @@
#include "common.h"
typedef void (im2col_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// const int64_t IC = tgpg[0];
const int64_t OH = tgpg[1];
const int64_t OW = tgpg[2];
const int64_t KH = ntg[1];
const int64_t KW = ntg[2];
int64_t in = tpitg[0];
const int64_t ikh = tpitg[1];
const int64_t ikw = tpitg[2];
const int64_t iic = tgpig[0];
const int64_t ioh = tgpig[1];
const int64_t iow = tgpig[2];
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
while (in < args.N) {
pdst[offset_dst] = 0.0f;
offset_dst += ntg[0]*args.CHW*OH*OW;
in += ntg[0];
}
} else {
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
while (in < args.N) {
pdst[offset_dst] = x[offset_src];
offset_dst += ntg[0]*args.CHW*OH*OW;
offset_src += ntg[0]*args.ofs0;
in += ntg[0];
}
}
}
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
// TODO: optimize
typedef void (im2col_ext_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col_ext(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int64_t KHW = (int64_t)args.KHW;
const int64_t d = tgpig[0] / args.CHW;
const int64_t chw = tgpig[0] % args.CHW;
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int64_t HW = tgpig[0] % KHW;
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= args.N) {
return;
}
const int64_t tpitg_1 = HW / args.KW;
const int64_t tpitg_2 = HW % args.KW;
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
const int64_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
pdst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
}
}
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
template <typename TK>
kernel void kernel_conv_2d(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
const uint thread_index = tg_index * threads_per_tg + local_thread;
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
uint64_t tmp = index;
const int32_t ow = tmp % args.OW; tmp /= args.OW;
const int32_t oh = tmp % args.OH; tmp /= args.OH;
const int32_t oc = tmp % args.OC; tmp /= args.OC;
const int32_t n = tmp;
float acc = 0.0f;
const int32_t base_x = ow*args.s0 - args.p0;
const int32_t base_y = oh*args.s1 - args.p1;
int32_t ky_start = 0;
if (base_y < 0) {
ky_start = (-base_y + args.d1 - 1)/args.d1;
}
int32_t ky_end = args.KH;
const int32_t y_max = args.IH - 1 - base_y;
if (y_max < 0) {
ky_end = ky_start;
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
ky_end = min(ky_end, y_max/args.d1 + 1);
}
int32_t kx_start = 0;
if (base_x < 0) {
kx_start = (-base_x + args.d0 - 1)/args.d0;
}
int32_t kx_end = args.KW;
const int32_t x_max = args.IW - 1 - base_x;
if (x_max < 0) {
kx_end = kx_start;
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
kx_end = min(kx_end, x_max/args.d0 + 1);
}
if (ky_start < ky_end && kx_start < kx_end) {
const uint64_t src_base_n = (uint64_t) n * args.nb13;
const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
for (int32_t ic = 0; ic < args.IC; ++ic) {
const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
for (int32_t ky = ky_start; ky < ky_end; ++ky) {
const int32_t iy = base_y + ky*args.d1;
const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
for (int32_t kx = kx_start; kx < kx_end; ++kx) {
const int32_t ix = base_x + kx*args.d0;
const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
const float x = *(device const float *)(src + src_offs);
const float w = (float) (*(device const TK *)(weights + w_offs));
acc += x * w;
}
}
}
}
const uint64_t dst_offs =
(uint64_t) n * args.nb3 +
(uint64_t) oc * args.nb2 +
(uint64_t) oh * args.nb1 +
(uint64_t) ow * args.nb0;
*(device float *)(dst + dst_offs) = acc;
}
}
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_1d(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const T * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]) {
// For output position j on the time axis, only input positions
// i such that i*s0 <= j < i*s0 + K
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
// (typically 2 for stride==K/2 transposed convs).
const int32_t j = tgpig[0];
const int32_t s0 = args.s0;
const int32_t K = args.K;
const int32_t IL = args.IL;
int32_t i_min;
{
int32_t a = j - K + 1;
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
}
int32_t i_max = j / s0;
if (i_max > IL - 1) i_max = IL - 1;
float v = 0.0f;
if (i_min <= i_max) {
for (int64_t c = 0; c < args.IC; c++) {
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
const int32_t input_offset = c * IL;
for (int32_t i = i_min; i <= i_max; i++) {
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
}
}
}
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
dst_ptr[0] = v;
}
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const half * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_conv_3d(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0, // Weights [IC * OC, KD, KH, KW]
device const char * src1, // Inputs [IC * N, ID, IH, IW]
device char * dst, // Outputs [OC * N, OD, OH, OW]
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
// 1. Un-flatten the spatial dimension from Grid X
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
if (spatial_idx >= args.OW * args.OH * args.OD) {
return; // Thread falls outside the spatial volume
}
int64_t od = spatial_idx / (args.OW * args.OH);
int64_t oh = (spatial_idx / args.OW) % args.OH;
int64_t ow = spatial_idx % args.OW;
// 2. Map Y to Channels, Z to Batch
int64_t oc = tgpig.y;
int64_t batch_idx = tgpig.z;
// 3. Calculate anchor coordinates in the Input volume
int64_t i_w_base = ow * args.s0 - args.p0;
int64_t i_h_base = oh * args.s1 - args.p1;
int64_t i_d_base = od * args.s2 - args.p2;
float sum = 0.0f;
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
for (int64_t ic = 0; ic < args.IC; ++ic) {
// ggml packs batch and channel together in the 4th dimension
int64_t src_cn_idx = batch_idx * args.IC + ic;
int64_t w_cn_idx = oc * args.IC + ic;
for (int64_t kz = 0; kz < args.KD; ++kz) {
int64_t id = i_d_base + kz * args.d2;
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
for (int64_t ky = 0; ky < args.KH; ++ky) {
int64_t ih = i_h_base + ky * args.d1;
if (ih < 0 || ih >= args.IH) continue;
for (int64_t kx = 0; kx < args.KW; ++kx) {
int64_t iw = i_w_base + kx * args.d0;
if (iw < 0 || iw >= args.IW) continue;
// Convert multi-dimensional coordinates to flat byte offsets
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
// Dereference memory and cast weights to f32 if they were f16
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
float i_val = *(device const float*)((device const char*)src1 + i_idx);
sum += w_val * i_val;
}
}
}
}
// 5. Write the accumulated value out to RAM
int64_t dst_cn_idx = batch_idx * args.OC + oc;
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
*(device float*)(dst + d_idx) = sum;
}
// Explicit instantiations so the JIT compiler can find them by name
template [[host_name("kernel_conv_3d_f32_f32")]]
kernel void kernel_conv_3d<float>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
// Explicit instantiation for f16 weights
template [[host_name("kernel_conv_3d_f16_f32")]]
kernel void kernel_conv_3d<half>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
+686
View File
@@ -0,0 +1,686 @@
#pragma once
#include "common.h"
#define GGML_COMMON_DECL_METAL
#define GGML_COMMON_IMPL_METAL
#if defined(GGML_METAL_EMBED_LIBRARY)
__embed_ggml-common.h__
#else
#include "ggml-common.h"
#endif
#define QK_NL 16 // shared by mul_mm and get_rows_q instantiations
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
reg = (type4)(*src);
}
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#if defined(GGML_METAL_HAS_BF16)
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#endif
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
device const uint8_t * qs = xb->qs;
const float d = xb->d;
const float neg_d = -d;
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
const uint8_t b0 = qs[byte_offset];
const uint8_t b1 = qs[byte_offset + 1];
float4x4 reg_f;
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
const float d = xb->d;
const float neg_d = -d;
const int base = il * 4;
const uint8_t byte = xb->qs[base / 8];
const int s = base % 8;
float4 reg_f;
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
reg = (type4) reg_f;
}
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
}
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
}
}
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + md;
reg[2*ii + 1] = d * x1 + md;
}
}
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + m;
reg[2*ii + 1] = d * x1 + m;
}
}
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
float4x4 reg_f;
for (int i = 0; i < 16; i++) {
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
for (int i = 0; i < 4; i++) {
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
}
}
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const uint8_t shr = il >= 1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
}
template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const short il4 = il%4;
const uint8_t shr = il >= 4 ? 4 : 0;
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
float dl, ml;
uint8_t sc = xb->scales[il];
q = q + 32*(il/8) + 16*(il&1);
il = (il/2)%4;
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
q = q + 32 * (il/8) + 16 * (il&1);
h = h + 16 * (il&1);
uint8_t m = 1 << (il/2);
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
const float ml = 4.f * dl;
il = (il/2) & 3;
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl *= coef;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
}
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
device const uchar * q = xb->qs;
short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.h;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il < 2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
device const uint8_t * q = xb->qs;
device const uint8_t * qh = xb->qh;
short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.f;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il<2 ? 0x0F : 0xF0;
const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
}
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint16_t * ql = (device const uint16_t *)xb->ql;
device const uint16_t * qh = (device const uint16_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
qh = qh + 16*(il/8) + 8*(il&1);
float sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
const float ml = d_all * sc * 32.f;
const float dl0 = d_all * sc;
const float dl1 = dl0 / 256.f;
const float dl2 = dl0 / (256.f * 256.f);
const float dl3 = dl0 / (256.f * 256.f * 256.f);
const uint8_t shr_h = il>2 ? 2 : 0;
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
const uint8_t shr_l = il>1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
}
}
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
device const uint16_t * q2 = xb->qs + 4*ib32;
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint16_t * q2 = xb->qs + 4*ib32;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * q3 = xb->qs + 8*ib32;
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
}
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * signs = qs + QK_K/8;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
for (int i = 0; i < 8; ++i) {
reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
}
}
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
const float d = xb->d;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint16_t * qh = xb->qh;
const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
const uint16_t h = qh[ib32] >> 6*il;
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml;
reg[1][i] = dl * (grid1[i] >> 4) + ml;
reg[2][i] = dl * (grid2[i] & 0xf) + ml;
reg[3][i] = dl * (grid2[i] >> 4) + ml;
}
}
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
device const uint16_t * sc = (device const uint16_t *)xb->scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = scale.f16;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * qh = xb->qh + 2*ib32 + il;
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
}
}
template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
template <typename type4>
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
reg[0] = d * kvalues_iq4nl_f[q8[0]];
reg[1] = d * kvalues_iq4nl_f[q8[1]];
reg[2] = d * kvalues_iq4nl_f[q8[2]];
reg[3] = d * kvalues_iq4nl_f[q8[3]];
}
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
const float d = (float)xb->d * (ls - 32);
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,250 @@
#include "common.h"
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
#if 1
template<short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
#define K FC_gated_delta_net_K
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B (n_seqs)
const uint i21 = tgpig.y; // H (head)
const uint i20 = tgpig.x*NSG + ty; // row within S_v
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
device const float * s_ptr = (device const float *) (s) + state_in_base;
float ls[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
// output state base offset: after attention scores
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
// output state per-slot size: S_v * S_v * H * n_seqs
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
// per-(seq,head) offset within a slot
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
for (short t = 0; t < args.ne22; t++) {
float s_k = 0.0f;
if (G == 1) {
const float g_exp = exp(g_ptr[0]);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= g_exp;
s_k += ls[j]*k_ptr[is];
}
} else {
// KDA
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= exp(g_ptr[is]);
s_k += ls[j]*k_ptr[is];
}
}
s_k = simd_sum(s_k);
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
float y = 0.0f;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] += k_ptr[is]*d;
y += ls[j]*q_ptr[is];
}
y = simd_sum(y);
if (tx == 0) {
dst_attn[t*args.ne21*S_v] = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
if (K > 1) {
const int target_slot = (int)args.ne22 - 1 - (int)t;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
}
}
if (K == 1) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
#undef S_v
#undef G
#undef K
}
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
#else
// a simplified version of the above
// no performance improvement, so keep the above version for now
template<typename T, short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
float lsf[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
lsf[j] = s_ptr[is*S_v];
}
thread T * ls = (thread T *) (lsf);
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
for (short t = 0; t < args.ne22; t++) {
device const T * qt_ptr = (device const T *) (q_ptr);
device const T * kt_ptr = (device const T *) (k_ptr);
device const T * gt_ptr = (device const T *) (g_ptr);
if (G == 1) {
*ls *= exp(g_ptr[0]);
} else {
// KDA
*ls *= exp(gt_ptr[tx]);
}
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
*ls += kt_ptr[tx]*d;
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
if (tx == 0) {
*dst_attn = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
dst_attn += args.ne21*S_v;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
device T * dstt_state = (device T *) (dst_state);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = lsf[j];
}
#undef S_v
#undef G
}
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
#endif
+347
View File
@@ -0,0 +1,347 @@
#include "common.h"
kernel void kernel_argmax_f32(
constant ggml_metal_kargs_argmax & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
float lmax = -INFINITY;
int32_t larg = -1;
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
if (x_row[i00] > lmax) {
lmax = x_row[i00];
larg = i00;
}
}
// find the argmax value in the block
float max_val = simd_max(lmax);
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
device int32_t * dst_i32 = (device int32_t *) dst;
threadgroup float * shared_maxval = (threadgroup float *) shmem;
threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
shared_maxval[tiisg] = -INFINITY;
shared_argmax[tiisg] = -1;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shared_maxval[sgitg] = max_val;
shared_argmax[sgitg] = arg_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = shared_maxval[tiisg];
arg_val = shared_argmax[tiisg];
float max_val_reduced = simd_max(max_val);
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
dst_i32[tgpig] = arg_val_reduced;
return;
}
dst_i32[tgpig] = arg_val;
}
kernel void kernel_diag_f32(
constant ggml_metal_kargs_diag & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]]) {
constexpr short NW = N_SIMDWIDTH;
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
}
}
kernel void kernel_roll_f32(
constant ggml_metal_kargs_roll & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *) src0;
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// apply shifts and wrap around
int64_t i00 = i0 - args.s0;
int64_t i01 = i1 - args.s1;
int64_t i02 = i2 - args.s2;
int64_t i03 = i3 - args.s3;
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
dst_ptr[dst_idx] = src0_ptr[src_idx];
}
}
template <typename T>
kernel void kernel_pad_impl(
constant ggml_metal_kargs_pad & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t k0 = tgpig.x/args.ne1;
const int32_t i1 = tgpig.x - k0*args.ne1;
const int32_t i03 = i3;
const int32_t i02 = i2;
const int32_t i01 = i1;
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
const int32_t i0 = k0*1024 + tpitg.x + l0;
if (i0 >= args.ne0) {
break;
}
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
}
}
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
// TODO: this is slow - optimize
kernel void kernel_pad_reflect_1d_f32(
constant ggml_metal_kargs_pad_reflect_1d & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.p0) {
dst_ptr[i0] = src0_ptr[args.p0 - i0];
} else if (i0 < args.ne0 - args.p1) {
dst_ptr[i0] = src0_ptr[i0 - args.p0];
} else {
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
}
}
}
}
kernel void kernel_arange_f32(
constant ggml_metal_kargs_arange & args,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_ptr[i0] = args.start + args.step * i0;
}
}
kernel void kernel_timestep_embedding_f32(
constant ggml_metal_kargs_timestep_embedding & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
int i = tgpig.x;
device float * embed_data = (device float *)(dst + i*args.nb1);
int half_ = args.dim / 2;
for (int j = tpitg.x; j < half_; j += ntg.x) {
float timestep = ((device float *)src0)[i];
float freq = (float)exp(-log((float)args.max_period) * j / half_);
float arg = timestep * freq;
embed_data[j ] = cos(arg);
embed_data[j + half_] = sin(arg);
}
if (args.dim % 2 != 0 && tpitg.x == 0) {
embed_data[2 * half_] = 0.f;
}
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
template<typename T>
kernel void kernel_count_equal(
constant ggml_metal_kargs_count_equal & args,
device const char * src0,
device const char * src1,
device atomic_int * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const short NSG = FC_count_equal_nsg;
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
int sum = 0;
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
const T v0 = *(device const T *)(base0 + i0*args.nb00);
const T v1 = *(device const T *)(base1 + i0*args.nb10);
sum += (v0 == v1);
}
sum = simd_sum(sum);
if (tiisg == 0) {
shmem_i32[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
float v = 0.0f;
if (tpitg.x < NSG) {
v = shmem_i32[tpitg.x];
}
float total = simd_sum(v);
if (tpitg.x == 0) {
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
}
}
}
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
+838
View File
@@ -0,0 +1,838 @@
#include "common.h"
#include "dequantize.h"
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
// each block_q contains 16*nl weights
#ifdef GGML_METAL_HAS_TENSOR
template<
typename SA, typename SA_4x4, typename SA_8x8,
typename SB, typename SB_2x4, typename SB_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * srcA,
device const char * srcB,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig [[threadgroup_position_in_grid]],
ushort tiitg [[thread_index_in_threadgroup]],
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
(void) sgitg;
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
const int K = args.ne00;
const int M = args.ne0;
const int N = args.ne1;
// Batch dimension handling
const int im = tgpig.z;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
// Batch offsets for srcA and srcB
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
// Tile dimensions
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
// Tile offsets in output matrix
const int ra = tgpig.y * NRA;
const int rb = tgpig.x * NRB;
// Threadgroup memory for dequantized A tile only
threadgroup SA * sa = (threadgroup SA *)(shmem);
// Work-item count for A loading
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
// tA wraps threadgroup memory
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
// tB wraps device memory directly
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
const int strideB = args.nb11 / sizeof(T1);
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
// Configure matmul operation
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
// Accumulate partial results over K dimension
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
// === PHASE 1: Dequantization of A into threadgroup memory ===
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
const int row = work / N_MM_NK;
const int k_chunk = work % N_MM_NK;
const int k_pos = loop_k + k_chunk * 16;
const short k_base = k_chunk * 16;
// Bounds check: skip device read if row is out of matrix bounds
if (ra + row < M) {
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
// Mirrors the legacy kernel's existing guard.
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
}
} else {
const int block_idx = k_pos / (16 * nl);
const short il = (k_pos / 16) % nl;
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
SA_4x4 temp_a;
dequantize_func(row_ptr + block_idx, il, temp_a);
FOR_UNROLL (short i = 0; i < 16; i++) {
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
}
}
} else {
// Zero-pad rows beyond matrix bounds
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// === PHASE 2: Tensor matmul ===
auto mA = tA.slice(0, 0);
auto mB = tB.slice(loop_k, rb);
mm.run(mB, mA, cT);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store result tile to output matrix (with batch offset)
// cT.store handles bounds checking via tD's extents (M, N)
device float * dstBatch = (device float *)dst + im * N * M;
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
cT.store(tD.slice(ra, rb));
}
#else
template<
typename S0, typename S0_4x4, typename S0_8x8,
typename S1, typename S1_2x4, typename S1_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z;
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*(r1 + lr1)
+ args.nb10*iy);
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
}
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
// if no bounds checks on the output are needed, we can directly write to device memory
device float * C = (device float *) dst +
(r0 + 32*(sgitg & 1)) + \
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
for (int j = tiitg; j < nr1; j += NR1) {
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*NR0);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < nr0/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < nr0; i++) {
*(D + i) = *(C + i);
}
}
}
}
}
#endif // GGML_METAL_HAS_TENSOR
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
device const char * src2,
device char * htpe,
device char * hids,
threadgroup char * shmem [[threadgroup(0)]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const short ide = tpitg; // expert id
uint32_t n_all = 0;
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
if (i21 + tpitg < args.ne21) {
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sids[i20] = src2_i32[i20];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short t = 0; t < ntg; t++) {
if (i21 + t >= args.ne21) {
break;
}
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
short sel = 0;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sel += (sids[i20] == ide)*(i20 + 1);
}
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
n_all += sel > 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
tpe_u32[ide] = n_all;
}
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
device const char * htpe,
device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z; // expert
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
device const int32_t * ids_i32 = (device const int32_t *) (hids);
const int32_t neh1 = tpe_u32[im];
if (r1 >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int id = ids_i32[im*args.ne21 + r1 + lr1];
const short i11 = (id % args.ne20) % args.ne11;
const short i12 = (id / args.ne20);
const short i13 = 0;
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*i11
+ args.nb10*iy);
#ifndef GGML_METAL_HAS_TENSOR
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
#else
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<4>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
#else
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
}
#endif
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifndef GGML_METAL_HAS_TENSOR
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
#else
auto sA = tA.slice(0, 0);
auto sB = tB.slice(0, 0);
mm.run(sB, sA, cT);
#endif
}
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifdef GGML_METAL_HAS_TENSOR
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
cT.store(tC);
#else
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
#endif
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short j = sgitg; j < nr1; j += 4) {
const int id = ids_i32[im*args.ne21 + r1 + j];
const short ide = id % args.ne20;
const short idt = id / args.ne20;
device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = (threadgroup float *) shmem + j*NR0;
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = tiisg;
for (; i < nr0/4; i += 32) {
*(D4 + i) = *(C4 + i);
}
i = (4*(nr0/4)) + tiisg;
for (; i < nr0; i += 32) {
*(D + i) = *(C + i);
}
}
}
//
// matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// indirect matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
File diff suppressed because it is too large Load Diff
+308
View File
@@ -0,0 +1,308 @@
#include "common.h"
// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
T sumft(0.0f);
float sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumft += x[i00];
}
sumf = dot(sumft, T(1.0f));
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
y[i00] = x[i00] - mean;
sumf += dot(y[i00], y[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float variance = sumf/args.ne00;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (y[i00]*scale);
}
if (F == 2) {
y[i00] = (y[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (x[i00]*scale);
}
if (F == 2) {
y[i00] = (x[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float scale = 1.0f/max(sqrt(sumf), args.eps);
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
device float * dst,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t ne = args.ne00*args.ne01*args.ne02;
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
int start = tgpig * gs;
int end = start + gs;
start += tpitg;
if (end >= ne) {
end = ne;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int j = start; j < end; j += ntg) {
tmp += src0[j];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float mean = tmp / gs;
tmp = 0.0f;
for (int j = start; j < end; j += ntg) {
float xi = src0[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float variance = tmp / gs;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int j = start; j < end; j += ntg) {
dst[j] *= scale;
}
}
+148
View File
@@ -0,0 +1,148 @@
#include "common.h"
kernel void kernel_pool_2d_max_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
res = MAX(res, i_ptr[i * args.IW + j]);
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (args.k0 * args.k1);
float res = 0;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
float cur = i_ptr[i * args.IW + j];
res += cur * scale;
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_1d_max_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = -INFINITY;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
int j = base + ki;
if (j < 0 || j >= args.IW){
continue;
}
float v = src[src_off + j];
acc = max(acc, v);
}
dst[dst_off + ow] = acc;
}
kernel void kernel_pool_1d_avg_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = 0.0f;
int cnt = 0;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
const int j = base + ki;
if (j < 0 || j >= args.IW) {
continue;
}
acc += src[src_off + j];
cnt += 1;
}
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
}
+213
View File
@@ -0,0 +1,213 @@
#pragma once
#include "common.h"
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
float sum_abs = 0.0f;
for (int j = 0; j < QK1_0; j++) {
sum_abs += fabs(src[j]);
}
dst.d = sum_abs / QK1_0;
for (int j = 0; j < QK1_0 / 8; j++) {
dst.qs[j] = 0;
}
for (int j = 0; j < QK1_0; j++) {
if (src[j] >= 0.0f) {
dst.qs[j / 8] |= (1 << (j % 8));
}
}
}
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_0/2 + j]*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
#pragma METAL fp math_mode(safe)
float min = FLT_MAX;
float max = -FLT_MAX;
for (int j = 0; j < QK4_1; j++) {
const float v = src[j];
if (min > v) min = v;
if (max < v) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK4_1/2 + j] - min)*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK5_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK5_0/2 + j]*id;
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
#pragma METAL fp math_mode(safe)
float max = src[0];
float min = src[0];
for (int j = 1; j < QK5_1; j++) {
const float v = src[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
dst.qs[j] = round(x0);
}
}
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_NL; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / kvalues_iq4nl_f[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
dst.qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl_f[xi0];
const float v1 = kvalues_iq4nl_f[xi1];
const float w0 = src[0 + j]*src[0 + j];
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
}
+389
View File
@@ -0,0 +1,389 @@
#include "common.h"
#include "dequantize.h"
#include "quantize.h"
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
}
}
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
template<short QK,
typename block_q,
void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
quantize_func(src, dst_data[i00]);
break;
}
}
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
break;
}
}
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
template<typename T>
kernel void kernel_concat(
constant ggml_metal_kargs_concat & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
if (i1 >= args.ne1) {
return;
}
int o[4] = {0, 0, 0, 0};
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
device const T * x;
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
} else {
x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
}
device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
*y = *x;
}
}
typedef decltype(kernel_concat<float>) kernel_concat_t;
template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
#endif
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
float4x4 temp;
dequantize_func(psrc + ind/nl, ind%nl, temp);
pdst[ind] = temp;
break;
}
}
template<typename T0, typename T>
kernel void kernel_get_rows_f(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
pdst[ind] = psrc[ind];
break;
}
}
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
quantize_func(src_row + 32*ind, dst_row[ind]);
}
}
template<typename T, typename TI>
kernel void kernel_set_rows_f(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
dst_row[ind] = (T) src_row[ind];
}
}
//
// get rows
//
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// set rows
//
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
+228
View File
@@ -0,0 +1,228 @@
#include "common.h"
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (args.np == 0) {
return;
}
// TODO: become function constant
const uint nsg = (ntg.x + 31) / 32;
float sumf = 0;
for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}
sumf = simd_sum(sumf);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
}
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_sum_rows_op
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) {
shmem_t[tiisg] = 0.0f;
}
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T>
kernel void kernel_cumsum_blk(
constant ggml_metal_kargs_cumsum_blk & args,
device const char * src0,
device char * tmp,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 +
args.nb01*i01 +
args.nb02*i02 +
args.nb03*i03);
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
float v = 0.0f;
if (i00 + tpitg.x < args.ne00) {
v = src0_row[i00 + tpitg.x];
}
float s = simd_prefix_inclusive_sum(v);
if (tiisg == N_SIMDWIDTH - 1) {
shmem_f32[sgitg] = s;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
s += shmem_f32[sgitg];
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] = s;
}
if (args.outb && tpitg.x == ntg.x - 1) {
device float * tmp_row = (device float *) tmp +
args.net0*i01 +
args.net0*args.net1*i02 +
args.net0*args.net1*args.net2*i03;
tmp_row[ib] = s;
}
}
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
template<typename T>
kernel void kernel_cumsum_add(
constant ggml_metal_kargs_cumsum_add & args,
device const char * tmp,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
if (ib == 0) {
return;
}
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * tmp_row = (device const float *) (tmp +
args.nbt1*i01 +
args.nbt2*i02 +
args.nbt3*i03);
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
}
}
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
+318
View File
@@ -0,0 +1,318 @@
#include "common.h"
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
*cos_theta = cos(theta) * mscale;
*sin_theta = sin(theta) * mscale;
if (FC_rope_is_back) {
*sin_theta *= -1.0f;
}
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
}
static void rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
template<typename T>
kernel void kernel_rope_norm(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_neox(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_multi(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
// mrope theta calculations
// note: the rest is the same as kernel_rope_neox
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
if (FC_rope_is_imrope) {
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
theta_base = (float) pos[i2 + args.ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
theta_base = (float) pos[i2 + args.ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + args.ne02 * 3];
}
} else {
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
}
// end of mrope
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_vision(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
const int ic = i0/2;
// mrope theta calculations (only support 2 dimensions)
const int sect_dims = args.sect_0 + args.sect_1;
const int sector = ic % sect_dims;
float p;
float theta_base;
if (sector < args.sect_1) {
p = (float) sector;
theta_base = (float) pos[i2];
} else {
p = (float) sector - args.sect_0;
theta_base = (float) pos[i2 + args.ne02];
}
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
+223
View File
@@ -0,0 +1,223 @@
#include "common.h"
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
// ALiBi
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
float max_val = simd_max(lmax);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
pdst[i00] *= inv_sum;
}
}
template<typename T>
kernel void kernel_soft_max_4(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
pdst4[i00] *= inv_sum;
}
}
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
@@ -0,0 +1,75 @@
#include "common.h"
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
ushort3 tgpig[[threadgroup_position_in_grid]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
constexpr short NW = N_SIMDWIDTH;
const short NSG = FC_solve_tri_nsg;
const short N = FC_solve_tri_n;
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
threadgroup float * sh0 = (threadgroup float *) shmem;
device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
for (short rr = 0; rr < N; rr += NSG) {
threadgroup_barrier(mem_flags::mem_threadgroup);
{
threadgroup float * sh0_cur = sh0 + sgitg*NP;
for (short t = 0; t*NW < N; ++t) {
const short idx = t*NW + tiisg;
sh0_cur[idx] = src0_ptr[idx];
}
src0_ptr += NSG*N;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (i01 >= args.ne10) {
continue;
}
for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
const short r = rr + ir;
threadgroup float * sh0_cur = sh0 + ir*NP;
float sum = 0.0f;
for (short t = 0; t*NW < r; ++t) {
const short idx = t*NW + tiisg;
sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
}
sum = simd_sum(sum);
if (tiisg == 0) {
const float diag = sh0_cur[r];
dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
}
}
}
}
+279
View File
@@ -0,0 +1,279 @@
#include "common.h"
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
kernel void kernel_ssm_conv_f32_f32(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
x[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
x[0] = sumf;
}
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
// Batched version: each threadgroup processes multiple tokens for better efficiency
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
kernel void kernel_ssm_conv_f32_f32_batched(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;
const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;
const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens
// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}
// Load conv weights (shared across all tokens for this row)
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
// Load source for this specific token
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
x[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_batched_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;
const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;
const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens
// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}
// Load conv weights (shared across all tokens for this row)
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
// Load source for this specific token
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
x[0] = sumf;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// Optimized version: reduces redundant memory loads by having one thread load shared values
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;
// Shared memory layout:
// [0..sgptg*NW-1]: partial sums for reduction (existing)
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
threadgroup float * shared_sums = shared;
threadgroup float * shared_x_dt = shared + sgptg * NW;
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
shared_sums[tpitg.x] = 0.0f;
const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
const int32_t ir = tgpig.y; // current head
const int32_t i3 = tgpig.z; // current seq
const int32_t nc = args.d_state;
const int32_t nr = args.d_inner;
const int32_t nh = args.n_head;
const int32_t ng = args.n_group;
const int32_t n_t = args.n_seq_tokens;
const int32_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int32_t i = i0 + i1*nc;
const int32_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = 0.0f;
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
const float A0 = A[i0%args.ne30];
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Pre-compute x_dt and dA for this batch of tokens
// Only first sgptg threads do the loads and expensive math
if (i0 < sgptg && i2 + i0 < n_t) {
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
device const float * x_t = x + i0 * args.ns12;
device const float * dt_t = dt + i0 * args.ns21;
const float dt0 = dt_t[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
shared_x_dt[i0] = x_t[0] * dtsp;
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float x_dt = shared_x_dt[t];
const float dA = exp(shared_dA[t] * A0);
s = (s0 * dA) + (B[i0] * x_dt);
const float sumf = simd_sum(s * C[i0]);
if (tiisg == 0) {
shared_sums[t*NW + sgitg] = sumf;
}
// recurse
s0 = s;
B += args.ns42;
C += args.ns52;
}
// Advance pointers for next batch
x += sgptg * args.ns12;
dt += sgptg * args.ns21;
threadgroup_barrier(mem_flags::mem_threadgroup);
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
}
y += sgptg*nh*nr;
}
s_buff[i] = s;
}
+69
View File
@@ -0,0 +1,69 @@
#include "common.h"
template<uint32_t ttype>
bool _ggml_vec_tri_cmp(const int i, const int r);
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
return i < r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
return i <= r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
return i > r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
return i >= r;
}
template<typename T, int ttype>
kernel void kernel_tri(
constant ggml_metal_kargs_tri & args,
device const char * src0,
device const char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
// Each thread is a single element of the row if ne00 < max threads per
// threadgroup, so this will loop once for each index that this thread is
// responsible for
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
// Use the comparison as a mask for branchless
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
}
}
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
#endif
+360
View File
@@ -0,0 +1,360 @@
#include "common.h"
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_unary_op
#define FC_CNT FC_unary_cnt
device const T0 * src0_ptr;
device T * dst_ptr;
int i0;
if (FC_CNT) {
i0 = tgpig.x;
src0_ptr = (device const T0 *) (src0);
dst_ptr = (device T *) (dst);
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int k0 = tgpig.x/args.ne01;
const int i01 = tgpig.x - k0*args.ne01;
i0 = k0*ntg.x + tpitg.x;
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
}
{
//threadgroup_barrier(mem_flags::mem_none);
if (!FC_CNT) {
if (i0 >= args.ne0) {
return;
}
}
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
if (FC_OP == OP_UNARY_NUM_FLOOR) {
dst_ptr[i0] = (T) floor(x);
}
if (FC_OP == OP_UNARY_NUM_CEIL) {
dst_ptr[i0] = (T) ceil(x);
}
if (FC_OP == OP_UNARY_NUM_ROUND) {
dst_ptr[i0] = (T) round(x);
}
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}
if (FC_OP == OP_UNARY_NUM_XIELU) {
const TC xi = x;
const TC gate = TC(xi > TC(0.0f));
const TC clamped = fmin(xi, TC(args.val));
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
}
}
#undef FC_OP
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
template<typename T>
kernel void kernel_reglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
}
}
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
template<typename T>
kernel void kernel_geglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
dst_row[i0] = (T)(gelu*x1);
}
}
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
template<typename T>
kernel void kernel_swiglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float silu = x0 / (1.0f + exp(-x0));
dst_row[i0] = (T)(silu*x1);
}
}
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
template<typename T>
kernel void kernel_swiglu_oai(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, args.limit);
x1 = max(min(x1, args.limit), -args.limit);
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = (T)out_glu;
}
}
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
template<typename T>
kernel void kernel_geglu_erf(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
dst_row[i0] = (T)(gelu_erf*x1);
}
}
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
template<typename T>
kernel void kernel_geglu_quick(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = (T)(gelu_quick*x1);
}
}
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
+179
View File
@@ -0,0 +1,179 @@
#include "common.h"
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
kernel void kernel_upscale_nearest_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3/args.sf3;
const int64_t i02 = i2/args.sf2;
const int64_t i01 = i1/args.sf1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int64_t i00 = i0/args.sf0;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_ptr[0] = src0_ptr[0];
}
}
static inline float bilinear_tri(float x) {
return MAX(0.0f, 1.0f - fabs(x));
}
kernel void kernel_upscale_bilinear_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
src0 += i03*args.nb03 + i02*args.nb02;
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (FC_upscale_aa) {
const float support0 = MAX(1.0f, 1.0f / args.sf0);
const float invscale0 = 1.0f / support0;
const float support1 = MAX(1.0f, 1.0f / args.sf1);
const float invscale1 = 1.0f / support1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
float sum = 0.0f;
float wsum = 0.0f;
for (int64_t sy = y_min; sy < y_max; ++sy) {
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
}
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
dst_ptr[i0] = v;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
const float v =
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
(*src10) * fd0 * (1.0f - fd1) +
(*src01) * (1.0f - fd0) * fd1 +
(*src11) * fd0 * fd1;
dst_ptr[i0] = v;
}
}
}
static inline float bicubic_weight1(float x) {
const float a = -0.75f;
return ((a + 2) * x - (a + 3)) * x * x + 1;
}
static inline float bicubic_weight2(float x) {
const float a = -0.75f;
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
}
kernel void kernel_upscale_bicubic_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = (int64_t)floor(f01);
const float fd1 = f01 - (float)i01;
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
const float w_y1 = bicubic_weight1(fd1);
const float w_y2 = bicubic_weight1(1.0f - fd1);
const float w_y3 = bicubic_weight2(2.0f - fd1);
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = (int64_t)floor(f00);
const float fd0 = f00 - (float)i00;
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
const float w_x1 = bicubic_weight1(fd0);
const float w_x2 = bicubic_weight1(1.0f - fd0);
const float w_x3 = bicubic_weight2(2.0f - fd0);
float sum = 0.0f;
for (int dy = -1; dy <= 2; ++dy) {
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
for (int dx = -1; dx <= 2; ++dx) {
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
sum += (*src_ptr) * wx * wy;
}
}
dst_ptr[i0] = sum;
}
}
+179
View File
@@ -0,0 +1,179 @@
#include "common.h"
kernel void kernel_rwkv_wkv6_f32(
device const float * k,
device const float * v,
device const float * r,
device const float * tf,
device const float * td,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _k[head_size];
threadgroup float _r[head_size];
threadgroup float _tf[head_size];
threadgroup float _td[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
_tf[tid] = tf[head_id * head_size + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0;
for (uint j = 0; j < head_size; j += 4) {
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
float4 temp = tf_vec * kv + s_vec;
y += dot(r_vec, temp);
s_vec = s_vec * td_vec + kv;
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
}
kernel void kernel_rwkv_wkv7_f32(
device const float * r,
device const float * w,
device const float * k,
device const float * v,
device const float * a,
device const float * b,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _r[head_size];
threadgroup float _w[head_size];
threadgroup float _k[head_size];
threadgroup float _a[head_size];
threadgroup float _b[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0, sa = 0.0;
float4 sa_vec(0.0);
for (uint j = 0; j < head_size; j += 4) {
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
sa_vec += a_vec * s_vec;
}
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
for (uint j = 0; j < head_size; j += 4) {
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(s_vec, r_vec);
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
}
// reduction in local memory, assumes #wave=4
-5
View File
@@ -108,9 +108,6 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
# the result-checking path computes a CPU reference graph via
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
if (GGML_VULKAN_DEBUG)
@@ -132,8 +129,6 @@ if (Vulkan_FOUND)
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
+69 -372
View File
@@ -493,20 +493,6 @@ struct vk_conv2d_pipeline_state {
}
};
struct vk_conv3d_pipeline_state {
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
uint32_t aligned;
bool operator<(const vk_conv3d_pipeline_state &b) const {
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
}
};
struct vk_solve_tri_pipeline_state {
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
: N(N), K(K) {}
@@ -699,7 +685,6 @@ struct vk_device_struct {
bool add_rms_fusion;
uint32_t partials_binding_alignment;
uint32_t max_nodes_per_submit;
bool shader_64b_indexing;
@@ -792,7 +777,6 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_back_f32;
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;
@@ -817,10 +801,14 @@ struct vk_device_struct {
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sqrt_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_diag[2];
vk_pipeline pipeline_clamp[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
@@ -852,10 +840,6 @@ struct vk_device_struct {
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
vk_pipeline pipeline_sqr[2];
vk_pipeline pipeline_sqrt[2];
vk_pipeline pipeline_sin[2];
vk_pipeline pipeline_cos[2];
vk_pipeline pipeline_xielu[2];
vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
@@ -887,7 +871,7 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
vk_pipeline pipeline_leaky_relu[2];
vk_pipeline pipeline_leaky_relu_f32;
vk_pipeline pipeline_silu_back_f32;
vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
@@ -940,8 +924,6 @@ struct vk_device_struct {
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
@@ -1687,41 +1669,6 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
}
struct vk_op_conv3d_push_constants {
uint32_t OC;
uint32_t IC;
uint32_t N;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
};
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
}
struct vk_op_conv2d_dw_push_constants {
uint32_t ne;
uint32_t batches;
@@ -4127,35 +4074,19 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#endif
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
spec.push_back(aligned ? 1u : 0u);
return spec;
};
const int mul_mat_id_param_count = 5;
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
if (mul_mat_id && spec.size() > 5) {
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
} else {
spec.push_back(aligned ? 1u : 0u);
}
if (mul_mat_id && spec.size() == 6) {
spec.push_back(32);
}
return spec;
};
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -4230,17 +4161,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -4353,32 +4284,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
// bf16 scalar path promotes to f32, no dot2 variant
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
@@ -4543,17 +4474,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l_int[TYPE]) \
@@ -4948,7 +4879,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -4973,7 +4903,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
@@ -5093,6 +5023,11 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5102,6 +5037,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -5121,12 +5058,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
CREATE_UNARY(sqr)
CREATE_UNARY(sqrt)
CREATE_UNARY(sin)
CREATE_UNARY(cos)
CREATE_UNARY(clamp)
CREATE_UNARY(leaky_relu)
CREATE_UNARY(xielu)
CREATE_UNARY(neg)
CREATE_UNARY(tanh)
@@ -5166,6 +5097,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
@@ -5382,7 +5314,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
// conv2d, conv_transpose_2d, conv3d
// conv2d, conv_transpose_2d
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
@@ -5445,8 +5377,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
};
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
// layout. cm1 needs Csh for output, so check before applying cm1 params.
// coopmat1 needs to store the output through shared memory, so check up front
// whether it'll fit and disable it before applying coopmat1 parameters.
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
conv2d_use_cm1 = false;
}
@@ -5538,53 +5470,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
#undef CREATE_CONV
#undef CREATE_CONVS
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
#define CREATE_CONV3D(type_suffix, spv_suffix) \
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
const vk_conv3d_pipeline_state &state = c.first; \
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
spec_constants_cpy.push_back(state.s0); \
spec_constants_cpy.push_back(state.s1); \
spec_constants_cpy.push_back(state.s2); \
spec_constants_cpy.push_back(state.p0); \
spec_constants_cpy.push_back(state.p1); \
spec_constants_cpy.push_back(state.p2); \
spec_constants_cpy.push_back(state.d0); \
spec_constants_cpy.push_back(state.d1); \
spec_constants_cpy.push_back(state.d2); \
spec_constants_cpy.push_back(state.KW); \
spec_constants_cpy.push_back(state.KH); \
spec_constants_cpy.push_back(state.KD); \
spec_constants_cpy.push_back(state.aligned); \
spec_constants_cpy.push_back(conv2d_csh_store); \
spec_constants_cpy.push_back(conv2d_WM); \
spec_constants_cpy.push_back(conv2d_WN); \
ggml_vk_create_pipeline( \
device, c.second, "conv3d" #type_suffix, \
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
}
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_CONV3D(_f32, _cm2)
CREATE_CONV3D(_f16_f32, _cm2)
} else
#endif
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (conv2d_use_cm1) {
CREATE_CONV3D(_f32, _cm1)
CREATE_CONV3D(_f16_f32, _cm1)
} else
#endif
if (conv2d_UNROLL) {
CREATE_CONV3D(_f32, _unroll)
CREATE_CONV3D(_f16_f32, _unroll)
} else {
CREATE_CONV3D(_f32, )
CREATE_CONV3D(_f16_f32, )
}
#undef CREATE_CONV3D
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -5879,14 +5764,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
// Submit at least every 100 nodes, in case there are workloads without as much matmul.
device->max_nodes_per_submit = 100;
const char* GGML_VK_MAX_NODES_PER_SUBMIT = getenv("GGML_VK_MAX_NODES_PER_SUBMIT");
if (GGML_VK_MAX_NODES_PER_SUBMIT != nullptr) {
uint32_t max_nodes_per_submit = std::stoul(GGML_VK_MAX_NODES_PER_SUBMIT);
device->max_nodes_per_submit = std::max(max_nodes_per_submit, 1u);
}
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -10417,11 +10294,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_get_rows_f32[src0->type];
}
return nullptr;
case GGML_OP_GET_ROWS_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_get_rows_back_f32;
}
return nullptr;
case GGML_OP_ACC:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_acc_f32;
@@ -10528,27 +10400,23 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_SQR:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqr_f32;
}
return nullptr;
case GGML_OP_SQRT:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sqrt_f32;
}
return nullptr;
case GGML_OP_SIN:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sin_f32;
}
return nullptr;
case GGML_OP_COS:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cos_f32;
}
return nullptr;
case GGML_OP_LOG:
@@ -10570,9 +10438,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
}
return nullptr;
case GGML_OP_PAD:
@@ -10940,9 +10807,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_LEAKY_RELU:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_leaky_relu_f32;
}
return nullptr;
case GGML_OP_CONV_2D:
@@ -11019,61 +10885,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
}
return nullptr;
case GGML_OP_CONV_3D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
const uint32_t KW = (uint32_t)src0->ne[0];
const uint32_t KH = (uint32_t)src0->ne[1];
const uint32_t KD = (uint32_t)src0->ne[2];
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
const uint32_t CRS = IC * KW * KH * KD;
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
const uint32_t aligned = ((OC % BS_K == 0) &&
(CRS % BS_CRS == 0) &&
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
if (src0->type == GGML_TYPE_F32) {
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
} else if (src0->type == GGML_TYPE_F16) {
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
} else {
return nullptr;
}
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = pipelines->find(conv3d_pipeline_state);
if (it != pipelines->end()) {
pipeline = it->second;
} else {
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
return pipeline;
}
return nullptr;
case GGML_OP_ADD1:
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add1_f16_f16;
@@ -11324,10 +11135,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
break;
case GGML_OP_GET_ROWS_BACK:
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
break;
case GGML_OP_ARGSORT:
GGML_ASSERT(0);
break;
@@ -11413,21 +11220,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
GGML_ABORT("invalid push constant type for CONV_2D");
}
break;
case GGML_OP_CONV_3D:
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
elements = { pc.OC, NPQ_blocks, 1 };
if (elements[1] > 512) {
elements[2] = CEIL_DIV(elements[1], 512);
elements[1] = 512;
}
} else {
GGML_ABORT("invalid push constant type for CONV_3D");
}
break;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
@@ -11444,7 +11236,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_CLAMP:
case GGML_OP_LEAKY_RELU:
case GGML_OP_PAD:
case GGML_OP_ROLL:
case GGML_OP_REPEAT:
@@ -11589,21 +11380,6 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
});
}
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -12311,10 +12087,8 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -13344,51 +13118,6 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
}
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
vk_op_conv3d_push_constants p{};
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
p.IW = static_cast<uint32_t>(ne10);
p.IH = static_cast<uint32_t>(ne11);
p.ID = static_cast<uint32_t>(ne12);
p.OW = static_cast<uint32_t>(ne0);
p.OH = static_cast<uint32_t>(ne1);
p.OD = static_cast<uint32_t>(ne2);
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
// total input element count must fit in a uint32.
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
}
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
vk_op_conv2d_dw_push_constants p{};
p.ne = ggml_nelements(dst);
@@ -13415,10 +13144,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const float * op_params = (const float *)dst->op_params;
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = op_params[0];
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
}
#ifdef GGML_VULKAN_RUN_TESTS
@@ -14521,10 +14247,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_GET_ROWS:
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_GET_ROWS_BACK:
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_ADD:
if (ctx->num_additional_fused_ops) {
@@ -14793,10 +14515,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_3D:
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_CONV_2D_DW:
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
@@ -16182,6 +15900,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).
// Also submit at least every 100 nodes, in case there are workloads without as much matmul.
int nodes_per_submit = 100;
int submitted_nodes = 0;
int submit_count = 0;
uint64_t mul_mat_bytes = 0;
@@ -16407,7 +16127,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i + ctx->num_additional_fused_ops >= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
@@ -17244,8 +16964,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
}
case GGML_OP_GET_ROWS_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SET_ROWS:
{
switch (op->type) {
@@ -17342,11 +17060,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_TRANSPOSE:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) &&
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -17365,9 +17084,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LEAKY_RELU:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -17567,13 +17285,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op));
}
case GGML_OP_CONV_3D:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op);
default:
return false;
}
@@ -18417,20 +18128,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const int32_t d0 = tensor->op_params[4];
const int32_t d1 = tensor->op_params[5];
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
} else if (tensor->op == GGML_OP_CONV_3D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t s2 = tensor->op_params[2];
const int32_t p0 = tensor->op_params[3];
const int32_t p1 = tensor->op_params[4];
const int32_t p2 = tensor->op_params[5];
const int32_t d0 = tensor->op_params[6];
const int32_t d1 = tensor->op_params[7];
const int32_t d2 = tensor->op_params[8];
const int32_t IC = tensor->op_params[9];
const int32_t N = tensor->op_params[10];
const int32_t OC = tensor->op_params[11];
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
@@ -1,431 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, KD, IC*OC]
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, OD, OC*N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t OC;
uint32_t IC;
uint32_t N;
// Tensor spatial sizes: input, output
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t OW;
uint32_t OH;
uint32_t OD;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
uint32_t OWOHODmp; uint32_t OWOHODL;
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint SHMEM_PAD = 4;
// Stride, padding, dilation
layout(constant_id = 6) const uint s0 = 1;
layout(constant_id = 7) const uint s1 = 1;
layout(constant_id = 8) const uint s2 = 1;
layout(constant_id = 9) const uint p0 = 0;
layout(constant_id = 10) const uint p1 = 0;
layout(constant_id = 11) const uint p2 = 0;
layout(constant_id = 12) const uint d0 = 1;
layout(constant_id = 13) const uint d1 = 1;
layout(constant_id = 14) const uint d2 = 1;
// Kernel spatial sizes
layout(constant_id = 15) const uint KW = 1;
layout(constant_id = 16) const uint KH = 1;
layout(constant_id = 17) const uint KD = 1;
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
layout(constant_id = 18) const uint aligned = 0;
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
layout(constant_id = 19) const uint csh_store = 0;
#ifdef COOPMAT
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
layout(constant_id = 20) const uint WM = 32;
layout(constant_id = 21) const uint WN = 32;
const uint TM = 16;
const uint TN = 16;
const uint TK = 16;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint warps_M = BS_K / WM;
const uint warps_N = BS_NPQ / WN;
#endif
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.OC;
uint32_t CRS = p.IC * KD * KH * KW;
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#if defined(COOPMAT2) || defined(COOPMAT)
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=OC
C=IC
D,R,S=KD,KH,KW
Z,P,Q=OD,OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
const uint32_t KHKW = KH * KW;
const uint32_t KDKHKW = KD * KHKW;
ic = crs_idx / KDKHKW;
uint32_t rem = crs_idx - ic * KDKHKW;
kd = rem / KHKW;
rem = rem - kd * KHKW;
kh = rem / KW;
kw = rem - kh * KW;
}
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
const uint32_t OWOH = p.OW * p.OH;
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
uint32_t rem = npq_idx - n * p.OD * OWOH;
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
rem = rem - od * OWOH;
oh = fastdiv(rem, p.OWmp, p.OWL);
ow = rem - oh * p.OW;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
if (B_idx_NPQ * BS_NPQ >= NPQ) {
return;
}
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#elif defined(COOPMAT)
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
}
const uint warp_r = gl_SubgroupID / warps_N;
const uint warp_c = gl_SubgroupID % warps_N;
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
uint32_t IC_idx_a;
uint32_t KD_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
/* Load kernel to A_block: (BS_K x BS_CRS)*/
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
if (aligned == 0) {
knl_idx = min(knl_idx, K * CRS - 1);
}
float val = knl_data[knl_idx];
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
uint32_t IC_idx_b;
uint32_t KD_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
// skip clamp when address can't go OOB
if (aligned == 0 || !dhw_in_bounds) {
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
}
float val = src_data[src_idx];
bool oob = false;
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
oob = true;
}
// also catches lower-bound underflow (idx wraps to 0x80000000+)
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
oob = true;
}
if (oob) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#elif defined(COOPMAT)
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#if defined(COOPMAT2) || defined(COOPMAT)
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
#ifdef COOPMAT
const bool use_staged_store = true;
#else
const bool use_staged_store = (csh_store != 0);
#endif
if (use_staged_store) {
#ifdef COOPMAT
// cm1: each subgroup stores its fragment grid into its Csh slot
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
}
#else
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
#endif
barrier();
// cooperative shmem->global: WG threads spread across BS_NPQ (the
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
const uint32_t store_iters = BS_K / store_rows_per_iter;
const uint32_t k_thread_offset = tid / BS_NPQ;
const uint32_t npq_thread = tid % BS_NPQ;
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
uint32_t K_idx = B_idx_K * BS_K + k_local;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
}
}
}
#ifdef COOPMAT2
else {
coopMatPerElementNV(matC, matC, perElemOpStore);
}
#endif
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx;
uint32_t OD_idx;
uint32_t OH_idx;
uint32_t OW_idx;
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
}
}
}
}
#endif
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}
@@ -463,7 +463,6 @@ void main() {
}
rowmaxf = max(rowmaxf, float(Sf[r][c]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// M = max(rowmax, Mold)
@@ -352,7 +352,6 @@ void main() {
}
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
}
rowmaxf += FATTN_KQ_MAX_OFFSET;
float Moldf = Mf[r];
// Compute max across the row
@@ -1,25 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint col = gl_GlobalInvocationID.x;
if (col >= p.ne20) {
return;
}
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
float sum = 0.0f;
for (uint i = 0; i < p.ne10; ++i) {
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
}
}
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
}
}
@@ -14,13 +14,16 @@ void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
const uint i3 = row / (p.ne11 * p.ne12);
const uint i3_offset = i3 * p.ne12 * p.ne11;
const uint i2 = (row - i3_offset) / p.ne11;
const uint i2_offset = i2 * p.ne11;
const uint i1 = row - i3_offset - i2_offset;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
sum[tid] += xi * xi;
}
@@ -36,6 +39,6 @@ void main() {
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
}
}
@@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float val = float(data_a[i]);
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
}
+23 -31
View File
@@ -38,7 +38,17 @@
#define LOAD_VEC_B 1
#endif
layout (constant_id = 11) const uint ALIGNED = 0;
// Load 2 values at once without affecting index calculations through LOAD_VEC
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
@@ -47,13 +57,6 @@ layout (constant_id = 11) const uint ALIGNED = 0;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(DATA_A_F32)
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
#elif defined(DATA_A_F16)
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
#elif defined(DATA_A_BF16)
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
#endif
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
@@ -62,7 +65,6 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
@@ -192,23 +194,13 @@ void main() {
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
#else
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
const uint LOAD_VEC_BATCH_A = 1;
#endif
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
@@ -247,15 +239,15 @@ void main() {
uint pos_a =
#ifdef MUL_MAT_ID
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
#else
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
#endif
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
#ifdef MUL_MAT_ID
uint pos_b = 0;
#else
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif
#ifdef COOPMAT
@@ -295,8 +287,8 @@ void main() {
barrier();
pos_a += BK / LOAD_VEC_A_EFF;
pos_b += BK / LOAD_VEC_B_EFF;
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
#ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
@@ -36,7 +36,6 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (constant_id = 5) const uint ALIGNED = 0;
layout (push_constant) uniform parameter
{
@@ -112,7 +111,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
};
uint _ne1;
layout (constant_id = 6) const uint subgroup_size = 32;
layout (constant_id = 5) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
@@ -298,12 +297,12 @@ void main() {
// Hint to the compiler that values are aligned (want 16B alignment).
// Quants are always block-aligned, no alignment needed.
if (ALIGNED != 0) {
#if ALIGNED
#if QUANT_K == 1
stride_a &= ~7;
stride_a &= ~7;
#endif
stride_b &= ~7;
#endif
stride_b &= ~7;
}
// Create layouts for both clamped and unclamped accesses
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
@@ -1,57 +1,50 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
return;
}
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
#elif LOAD_VEC_A == 4
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
data_a_scalar[idx + 1]);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
data_a[idx + 1]);
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
if (ALIGNED != 0) {
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
return;
}
#endif
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -533,85 +526,75 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#if !defined(MUL_MAT_ID)
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
#elif LOAD_VEC_B == 4
if (ALIGNED != 0) {
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (idx_n < p.N && block + row * 2 < end_k) {
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#else
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
if (ALIGNED != 0) {
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
return;
}
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
#elif LOAD_VEC_B == 4
if (ALIGNED != 0) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
return;
}
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint row_i = ic * BN + col;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
#endif
+10 -10
View File
@@ -1,26 +1,26 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared vec2 sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
sum[tid] = vec2(0.0f, 0.0f);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
const float xi = float(data_a[a_base + i0*p.nb00]);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const float xi = float(data_a[row*p.KX + col]);
sum[tid].x += xi;
sum[tid].y += xi * xi;
}
@@ -34,11 +34,11 @@ void main() {
barrier();
}
const float mean = sum[0].x / p.ne00;
const float var = sum[0].y / p.ne00 - mean * mean;
const float mean = sum[0].x / p.KX;
const float var = sum[0].y / p.KX - mean * mean;
const float inv_std = inversesqrt(var + p.param1);
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
}
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
}
@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
}
@@ -17,30 +17,6 @@ float op_neg(float x) {
return -x;
}
float op_sqr(float x) {
return x * x;
}
float op_sqrt(float x) {
return sqrt(x);
}
float op_sin(float x) {
return sin(x);
}
float op_cos(float x) {
return cos(x);
}
float op_clamp(float x) {
return clamp(x, p.param1, p.param2);
}
float op_leaky_relu(float x) {
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
}
float op_step(float x) {
return x >= 0.0f ? 1.0f : 0.0f;
}
@@ -11,7 +11,6 @@
#include <future>
#include <queue>
#include <condition_variable>
#include <atomic>
#include <cstdio>
#include <cstring>
#include <cstdlib>
@@ -35,9 +34,6 @@
std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
static std::atomic<bool> compile_failed{false};
std::locale c_locale("C");
std::string GLSLC = "glslc";
@@ -82,7 +78,7 @@ enum MatMulIdType {
namespace {
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
@@ -131,11 +127,8 @@ int execute_command(std::vector<std::string>& command, std::string& stdout_str,
CloseHandle(stdout_read);
CloseHandle(stderr_read);
WaitForSingleObject(pi.hProcess, INFINITE);
DWORD exit_code = 1;
GetExitCodeProcess(pi.hProcess, &exit_code);
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
return (int)exit_code;
#else
int stdout_pipe[2];
int stderr_pipe[2];
@@ -182,9 +175,7 @@ int execute_command(std::vector<std::string>& command, std::string& stdout_str,
close(stdout_pipe[0]);
close(stderr_pipe[0]);
int status = 0;
waitpid(pid, &status, 0);
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
waitpid(pid, nullptr, 0);
}
#endif
}
@@ -381,14 +372,13 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// }
// std::cout << std::endl;
int exit_code = execute_command(cmd, stdout_str, stderr_str);
if (exit_code != 0 || !stderr_str.empty()) {
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
execute_command(cmd, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n";
for (const auto& part : cmd) {
std::cerr << part << " ";
}
std::cerr << "\n\n" << stderr_str << std::endl;
compile_failed = true;
return;
}
@@ -408,7 +398,6 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
shader_fnames.push_back(std::make_pair(name, out_path));
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
compile_failed = true;
}
}
@@ -550,9 +539,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -574,7 +565,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
#endif
{
if (!dot2) {
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}
@@ -591,6 +583,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
// For unaligned, load one at a time for f32/f16, or two at a time for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
@@ -603,11 +597,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@@ -854,12 +850,21 @@ void process_shaders() {
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
@@ -886,18 +891,6 @@ void process_shaders() {
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
@@ -955,6 +948,7 @@ void process_shaders() {
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -1066,31 +1060,6 @@ void process_shaders() {
}
}
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
std::map<std::string, std::string> defines = {
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
{"UNROLL", unroll ? "[[unroll]]" : ""},
};
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (unroll) {
auto cm2_defines = defines;
cm2_defines["COOPMAT2"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (unroll) {
auto cm1_defines = defines;
cm1_defines["COOPMAT"] = "1";
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
}
#endif
}
}
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
@@ -1282,11 +1251,6 @@ int main(int argc, char** argv) {
process_shaders();
if (compile_failed) {
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
return EXIT_FAILURE;
}
write_output_files();
return EXIT_SUCCESS;
@@ -905,12 +905,11 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
int vectorized;
uint32_t num_cols;
bool use_mmvq;
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
use_mmvq == other.use_mmvq;
}
};
@@ -920,7 +919,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.use_mmvq);
return seed;
}
@@ -995,12 +993,11 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
ggml_type src0_type;
ggml_type src1_type;
uint32_t n_experts;
uint32_t num_cols;
int vectorized;
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
num_cols == other.num_cols && vectorized == other.vectorized;
vectorized == other.vectorized;
}
};
@@ -1010,7 +1007,6 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.src0_type);
ggml_webgpu_hash_combine(seed, key.src1_type);
ggml_webgpu_hash_combine(seed, key.n_experts);
ggml_webgpu_hash_combine(seed, key.num_cols);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
@@ -1111,7 +1107,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
const ggml_tensor * src1,
bool supports_dot_product,
const std::string & vendor) {
if (src1->ne[1] <= 4) {
if (src1->ne[1] == 1) {
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
if (supports_dp4a && supports_dot_product) {
switch (src1->type) {
@@ -1893,7 +1889,6 @@ class ggml_webgpu_shader_lib {
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.num_cols = context.dst->ne[1];
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
@@ -2009,7 +2004,6 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
auto processed = preprocessor.preprocess(shader_src, defines);
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
@@ -2427,7 +2421,6 @@ class ggml_webgpu_shader_lib {
if (key.vectorized) {
variant += "_vectorized";
}
defines.push_back(std::string("NUM_COLS=1"));
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
+10 -12
View File
@@ -1418,17 +1418,15 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
const size_t q8_src1_align_offset = ROUNDUP_POW2(
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t q8_src1_binding_size = ROUNDUP_POW2(
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
const size_t q8_src1_binding_size =
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
WEBGPU_STORAGE_BUF_BINDING_MULT);
std::vector<uint32_t> q8_params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
@@ -1444,7 +1442,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
uint32_t q8_wg_x = 1;
uint32_t q8_wg_y = 1;
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
@@ -1458,7 +1456,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * dst) {
// Determine if this is a mat-vec operation
bool use_mat_vec = (dst->ne[1] <= 4);
bool is_vec = (dst->ne[1] == 1);
// use MMVQ path for mat-vec
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
@@ -1484,7 +1482,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
webgpu_pipeline pipeline;
std::vector<webgpu_dispatch_desc> dispatches;
if (use_mat_vec) {
if (is_vec) {
if (use_mmvq) {
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
}
@@ -1531,7 +1529,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_mat_vec) {
if (is_vec) {
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
@@ -3693,8 +3691,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
ctx->webgpu_global_ctx->vendor);
if (use_mmvq) {
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
const size_t q8_src1_size =
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
res = ROUNDUP_POW2(res + q8_src1_size +
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
WEBGPU_STORAGE_BUF_BINDING_MULT);
@@ -4270,7 +4268,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_L2_NORM:
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_ROPE:
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
@@ -103,7 +103,7 @@ fn main(
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[0][row]);
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
@@ -126,7 +126,7 @@ fn main(
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[0][row];
partial_sums[partial_index(row, thread_id)] = acc[row];
}
workgroupBarrier();
@@ -91,67 +91,61 @@ fn main(
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
#ifdef MMVQ
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
#else
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
#endif
for (var col = 0u;col < NUM_COLS;col += 1) {
#ifdef USE_SUBGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[col][row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + col * params.m + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[col][row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let subgroup_total = subgroupAdd(acc[row]);
if (subgroup_invocation_id == 0u) {
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
}
}
workgroupBarrier();
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
let output_row = row_base + row;
var row_acc = 0.0f;
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
row_acc += partial_sums[partial_index(row, k)];
}
let row_total = subgroupAdd(row_acc);
if (subgroup_invocation_id == 0) {
dst[dst_idx_base + row] = row_total;
}
}
#endif
#ifdef USE_WORKGROUP_REDUCTION
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] = acc[row];
}
workgroupBarrier();
var stride = WG_SIZE / 2u;
while (stride > 0) {
if (thread_id < stride) {
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
}
}
workgroupBarrier();
stride = stride / 2;
}
if (thread_id < OUTPUTS_PER_WG) {
let output_row = row_base + thread_id;
if (output_row < params.m) {
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
}
}
#endif
}
File diff suppressed because it is too large Load Diff
@@ -51,7 +51,10 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
#endif // MUL_ACC_Q4_0
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE_BYTES 20
@@ -82,7 +85,10 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
f32(load_f16_at_src0(block_byte_base + 2u))
);
}
#endif // MUL_ACC_Q4_1
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
}
#endif
#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE_BYTES 34
@@ -105,48 +111,46 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
fn get_dm(block_byte_base: u32) -> f32 {
return f32(load_f16_at_src0(block_byte_base));
}
#endif // MUL_ACC_Q8_0
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
return f32(row_sum) * (da * b_ds);
}
#endif
#if defined(LEGACY_QUANTS)
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
#ifdef LEGACY_QUANTS
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
var row_sum = 0;
let a_repacked = repack_a(a_byte_base, b_inner_id);
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
}
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
let num_blocks = params.k / BLOCK_SIZE;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let inner_id = thread_id % THREADS_PER_BLOCK;
let b_inner_id = thread_id % THREADS_PER_BLOCK;
let b_block_idx = src1q_idx_base + block;
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
let b_ds = repack_b_dm(b_block_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let a_repacked = repack_a(block_byte_base, inner_id);
let da = get_dm(block_byte_base);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
let b_repacked = repack_b_qs(src1q_idx, inner_id);
let b_ds = repack_b_dm(src1q_idx);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
#if defined(MUL_ACC_Q4_0)
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_0
#if defined(MUL_ACC_Q4_1)
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
#endif // MUL_ACC_Q4_1
#if defined(MUL_ACC_Q8_0)
acc[col][row] += f32(row_sum) * (da * b_ds);
#endif // MUL_ACC_Q8_0
}
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
}
}
}
return acc;
}
#endif // LEGACY_QUANTS
#endif
#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE_BYTES 84
@@ -187,7 +191,22 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
}
#endif // MUL_ACC_Q2_K
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
}
#endif
#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE_BYTES 144
@@ -246,52 +265,39 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
return vec2<f32>(scale, min_val);
}
#endif // MUL_ACC_Q4_K
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
let a_repacked = repack_a(a_byte_base, tid);
let dm = get_dm(a_byte_base);
let scale_min = get_scale_min(a_byte_base, tid);
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
}
#endif
#ifdef K_QUANTS
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
var acc: array<f32, OUTPUTS_PER_WG>;
let tid = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let a_repacked = repack_a(block_byte_base, tid);
let dm = get_dm(block_byte_base);
let scale_min = get_scale_min(block_byte_base, tid);
for (var col = 0u;col < NUM_COLS;col += 1) {
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
let b_repacked = repack_b_qs(src1q_idx, tid);
let b_ds = repack_b_dm(src1q_idx);
#if defined(MUL_ACC_Q2_K)
let scale_q = i32(scale_min.x);
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
#endif // MUL_ACC_Q2_K
#if defined(MUL_ACC_Q4_K)
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
#endif // MUL_ACC_Q4_K
}
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
}
}
}
return acc;
}
#endif // K_QUANTS
#endif
@@ -9,11 +9,9 @@ requires packed_4x8_integer_dot_product;
struct Params {
offset_src1: u32,
stride_11: u32,
stride_12: u32,
stride_13: u32,
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
};
@@ -59,28 +57,25 @@ fn main(
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let thread_id = local_id.x;
let ne0_vec4 = params.ne0 / 4u;
let num_vec4 = params.ne0 / 4u;
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
let total_batches = wg_per_vec * params.ne2 * params.ne3;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
if (wg_linear >= total_batches) {
return;
}
let vec_idx = wg_linear / wg_per_vec;
let src13_idx = vec_idx / (params.ne2 * params.ne1);
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
let src12_idx = vec_ne12_num / params.ne1;
let src11_idx = vec_ne12_num % params.ne1;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
let src1_idx_vec4_base = src1_idx_base / 4u;
let blocks_per_row = params.ne0 / 32u;
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
let src11_wg_idx = wg_linear % wg_per_vec;
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
let qs_idx = thread_id % 8u;
@@ -90,7 +85,7 @@ fn main(
var thread_amax = 0.0;
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
let is_valid = src11_vec4_idx < ne0_vec4;
let is_valid = src11_vec4_idx < num_vec4;
#ifdef USE_SUBGROUP_REDUCTION
-1
View File
@@ -359,7 +359,6 @@ class Keys:
CHUNK_SIZE = "clip.audio.chunk_size"
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
MAX_POS_EMB = "clip.audio.max_pos_emb"
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
class Attention:
HEAD_COUNT = "clip.audio.attention.head_count"
-3
View File
@@ -1310,9 +1310,6 @@ class GGUFWriter:
def add_audio_max_pos_emb(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
def add_audio_projector_window_size(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
+1 -7
View File
@@ -57,25 +57,19 @@ oppoll=
opflt=
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
opfuse=
[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC"
vmem=
[ "$VM" != "" ] && vmem="GGML_HEXAGON_VMEM=$VM"
mbuf=
[ "$MB" != "" ] && mbuf="GGML_HEXAGON_MBUF=$MB"
mmsel=
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
set -x
adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 1024 -fa on \

Some files were not shown because too many files have changed in this diff Show More