mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-24 14:47:39 +02:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bec3083830 |
+4
-10
@@ -396,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -408,10 +408,6 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex,
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
|
||||
if (callback) {
|
||||
opts.callback = callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
common_download_opts sub_opts = opts;
|
||||
@@ -588,11 +584,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
@@ -601,6 +594,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
|
||||
+1
-5
@@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
@@ -134,10 +133,7 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
common_download_callback * callback = nullptr);
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
+53
-103
@@ -90,93 +90,41 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role) {
|
||||
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
|
||||
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
|
||||
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
|
||||
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
|
||||
return COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * common_chat_role_to_string(common_chat_role role) {
|
||||
switch (role) {
|
||||
case COMMON_CHAT_ROLE_SYSTEM: return "system";
|
||||
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
|
||||
case COMMON_CHAT_ROLE_USER: return "user";
|
||||
case COMMON_CHAT_ROLE_TOOL: return "tool";
|
||||
case COMMON_CHAT_ROLE_UNKNOWN: return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
json common_chat_msg_delimiters::to_json() const {
|
||||
json result = json::array();
|
||||
for (const auto & d : delimiters) {
|
||||
result.push_back({
|
||||
{ "role", common_chat_role_to_string(d.role) },
|
||||
{ "delimiter", d.delimiter },
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
}
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
|
||||
std::vector<std::pair<common_chat_role, size_t>> matches;
|
||||
|
||||
auto skip = skips.begin();
|
||||
for (size_t i = 0; i < tokens.size();) {
|
||||
if (skip != skips.end() && i == skip->first) {
|
||||
i += skip->second;
|
||||
++skip;
|
||||
continue;
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
for (const auto & d : delimiters) {
|
||||
if (i + d.tokens.size() > tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
|
||||
matches.emplace_back(d.role, i);
|
||||
break;
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
|
||||
|
||||
common_chat_msg_spans spans;
|
||||
for (size_t i = 0; i + 1 < matches.size(); i++) {
|
||||
const auto & curr = matches[i];
|
||||
const auto & next = matches[i + 1];
|
||||
spans.add(curr.first, curr.second, next.second - curr.second);
|
||||
}
|
||||
});
|
||||
|
||||
return spans;
|
||||
}
|
||||
@@ -1133,13 +1081,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1280,10 +1228,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2082,15 +2030,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2578,15 +2526,17 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
common_chat_msg_delimiters delimiters;
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
}
|
||||
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
|
||||
+6
-65
@@ -143,75 +143,15 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
enum common_chat_role {
|
||||
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
struct common_chat_msg_span {
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string role;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
|
||||
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
nlohmann::ordered_json to_json() const;
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -279,7 +219,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -385,4 +325,5 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
|
||||
+1
-1
@@ -609,7 +609,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
|
||||
@@ -341,9 +341,6 @@ set(GGML_PUBLIC_HEADERS
|
||||
include/gguf.h)
|
||||
|
||||
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
|
||||
#if (GGML_METAL)
|
||||
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
|
||||
#endif()
|
||||
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
|
||||
install(TARGETS ggml-base LIBRARY)
|
||||
|
||||
|
||||
@@ -24,62 +24,119 @@ if (GGML_METAL_NDEBUG)
|
||||
endif()
|
||||
|
||||
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
|
||||
set(METALLIB_KERNELS_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/kernels/common.h")
|
||||
set(METALLIB_KERNELS_DEQUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/dequantize.h")
|
||||
set(METALLIB_KERNELS_QUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantize.h")
|
||||
|
||||
set(METALLIB_KERNEL_SOURCES
|
||||
kernels/fa.metal
|
||||
kernels/mul_mv.metal
|
||||
kernels/mul_mm.metal
|
||||
kernels/quantize.metal
|
||||
kernels/softmax.metal
|
||||
kernels/norm.metal
|
||||
kernels/unary.metal
|
||||
kernels/binbcast.metal
|
||||
kernels/reduce.metal
|
||||
kernels/tri.metal
|
||||
kernels/ssm.metal
|
||||
kernels/wkv.metal
|
||||
kernels/gated_delta_net.metal
|
||||
kernels/solve_tri.metal
|
||||
kernels/rope.metal
|
||||
kernels/conv.metal
|
||||
kernels/upscale.metal
|
||||
kernels/argsort.metal
|
||||
kernels/pool.metal
|
||||
kernels/misc.metal
|
||||
)
|
||||
|
||||
if (GGML_METAL_EMBED_LIBRARY)
|
||||
enable_language(ASM)
|
||||
|
||||
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
|
||||
|
||||
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
|
||||
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
|
||||
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
|
||||
|
||||
# merge ggml-common.h and ggml-metal.metal into a single file
|
||||
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
|
||||
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
|
||||
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
|
||||
set(METALLIB_EMBED_ASM_FILES "")
|
||||
foreach(src ${METALLIB_KERNEL_SOURCES})
|
||||
get_filename_component(kind ${src} NAME_WE)
|
||||
# symbol names must be valid C identifiers ('-' is not allowed)
|
||||
string(REPLACE "-" "_" kind_sym ${kind})
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo "Embedding Metal library"
|
||||
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
|
||||
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
|
||||
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
|
||||
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
|
||||
COMMENT "Generate assembly for embedded Metal library"
|
||||
VERBATIM
|
||||
)
|
||||
set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${kind}.metal")
|
||||
set(EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.metal")
|
||||
set(ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.s")
|
||||
|
||||
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
|
||||
# only prepend headers that this source actually includes
|
||||
set(HEADERS_FOR_SRC ${METALLIB_KERNELS_COMMON})
|
||||
file(STRINGS ${SRC} _has_dequantize REGEX "#include \"dequantize\\.h\"")
|
||||
file(STRINGS ${SRC} _has_quantize REGEX "#include \"quantize\\.h\"")
|
||||
if(_has_dequantize)
|
||||
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_DEQUANTIZE})
|
||||
endif()
|
||||
if(_has_quantize)
|
||||
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_QUANTIZE})
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${ASM}"
|
||||
# Step 1: concatenate shared headers + this kernel source
|
||||
COMMAND cat ${HEADERS_FOR_SRC} ${SRC} > "${EMBED}.tmp1"
|
||||
# Step 2: remove internal #include and #pragma once
|
||||
COMMAND sed -e "/\#include \"common.h\"/d" -e "/\#include \"dequantize.h\"/d" -e "/\#include \"quantize.h\"/d" -e "/\#pragma once/d" < "${EMBED}.tmp1" > "${EMBED}.tmp2"
|
||||
# Step 3: inline ggml-common.h (replacing __embed_ggml-common.h__ sentinel)
|
||||
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${EMBED}.tmp2" > "${EMBED}.tmp3"
|
||||
# Step 4: inline ggml-metal-impl.h
|
||||
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${EMBED}.tmp3" > "${EMBED}"
|
||||
# Step 5: emit an asm chunk with kind-specific start/end symbols
|
||||
# note: '-' is illegal in C symbols, so we use kind_sym; the macOS
|
||||
# section name is limited to 16 chars so we keep it shared
|
||||
# across kinds (__ggml_metallib) and only vary the global symbols.
|
||||
COMMAND echo ".section __DATA,__ggml_metallib" > "${ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_${kind_sym}_start" >> "${ASM}"
|
||||
COMMAND echo "_ggml_metallib_${kind_sym}_start:" >> "${ASM}"
|
||||
COMMAND echo .incbin "\"${EMBED}\"" >> "${ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_${kind_sym}_end" >> "${ASM}"
|
||||
COMMAND echo "_ggml_metallib_${kind_sym}_end:" >> "${ASM}"
|
||||
DEPENDS ../ggml-common.h ggml-metal-impl.h
|
||||
kernels/common.h kernels/dequantize.h kernels/quantize.h
|
||||
kernels/${kind}.metal
|
||||
COMMENT "Generate embedded Metal library for ${kind}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
list(APPEND METALLIB_EMBED_ASM_FILES "${ASM}")
|
||||
endforeach()
|
||||
|
||||
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM_FILES})
|
||||
else()
|
||||
# copy metal files to bin directory
|
||||
# copy header files to bin directory
|
||||
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
|
||||
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
||||
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels")
|
||||
configure_file(kernels/common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/common.h COPYONLY)
|
||||
configure_file(kernels/dequantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/dequantize.h COPYONLY)
|
||||
configure_file(kernels/quantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/quantize.h COPYONLY)
|
||||
|
||||
foreach(src ${METALLIB_KERNEL_SOURCES})
|
||||
configure_file(${src} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} COPYONLY)
|
||||
endforeach()
|
||||
|
||||
if (GGML_METAL_SHADER_DEBUG)
|
||||
# custom command to do the following:
|
||||
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
|
||||
# xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
|
||||
#
|
||||
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
|
||||
# disabling fast math is needed in order to pass tests/test-backend-ops
|
||||
# note: disabling fast math is needed in order to pass tests/test-backend-ops
|
||||
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
|
||||
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
|
||||
# ref: https://github.com/ggml-org/whisper.cpp/issues/1720
|
||||
# note: adding -g causes segmentation fault during compile
|
||||
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
|
||||
set(XC_FLAGS -fno-fast-math -fno-inline)
|
||||
else()
|
||||
set(XC_FLAGS -O3)
|
||||
endif()
|
||||
|
||||
# Append macOS metal versioning flags
|
||||
if (GGML_METAL_MACOSX_VERSION_MIN)
|
||||
message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
|
||||
list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
|
||||
@@ -90,35 +147,46 @@ else()
|
||||
list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
|
||||
endif()
|
||||
|
||||
# Compile each kernel source to .air, then link into default.metallib
|
||||
set(AIR_FILES "")
|
||||
foreach(src ${METALLIB_KERNEL_SOURCES})
|
||||
get_filename_component(name ${src} NAME_WE)
|
||||
set(AIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${name}.air")
|
||||
list(APPEND AIR_FILES ${AIR})
|
||||
add_custom_command(
|
||||
OUTPUT ${AIR}
|
||||
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -I ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} -o ${AIR}
|
||||
DEPENDS ${src} kernels/common.h kernels/dequantize.h kernels/quantize.h ${METALLIB_COMMON} ggml-metal-impl.h
|
||||
COMMENT "Compiling ${src}"
|
||||
VERBATIM
|
||||
)
|
||||
endforeach()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
|
||||
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND xcrun -sdk macosx metallib ${AIR_FILES} -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
|
||||
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
|
||||
COMMENT "Compiling Metal kernels"
|
||||
)
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h
|
||||
COMMAND rm -rf ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels
|
||||
DEPENDS ${AIR_FILES}
|
||||
COMMENT "Linking Metal kernels into default.metallib"
|
||||
)
|
||||
|
||||
# FIXME: only add to the ggml-metal target?
|
||||
add_custom_target(
|
||||
ggml-metal-lib ALL
|
||||
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
)
|
||||
)
|
||||
endif() # GGML_METAL_EMBED_LIBRARY
|
||||
|
||||
if (NOT GGML_METAL_EMBED_LIBRARY)
|
||||
install(
|
||||
FILES src/ggml-metal/ggml-metal.metal
|
||||
PERMISSIONS
|
||||
OWNER_READ
|
||||
OWNER_WRITE
|
||||
GROUP_READ
|
||||
WORLD_READ
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR}/kernels
|
||||
FILES_MATCHING PATTERN "*.metal" PATTERN "*.h"
|
||||
)
|
||||
|
||||
install(
|
||||
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
)
|
||||
install(
|
||||
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -94,8 +94,63 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
|
||||
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
|
||||
}
|
||||
|
||||
//
|
||||
// MTLLibrary collection (one library per op-source, compiled separately)
|
||||
//
|
||||
|
||||
// Single source of truth for the per-kind metal libraries. The order here
|
||||
// defines the enum values and every per-kind table below, so adding a library
|
||||
// is a one-line change here (plus adding its source to CMakeLists.txt).
|
||||
// X(suffix, name): name is both the kernels/<name>.metal basename and the
|
||||
// ggml_metallib_<name>_{start,end} embed-symbol stem.
|
||||
#define GGML_METAL_LIBS \
|
||||
X(FA, fa) \
|
||||
X(MUL_MV, mul_mv) \
|
||||
X(MUL_MM, mul_mm) \
|
||||
X(QUANTIZE, quantize) \
|
||||
X(SOFTMAX, softmax) \
|
||||
X(NORM, norm) \
|
||||
X(UNARY, unary) \
|
||||
X(BINBCAST, binbcast) \
|
||||
X(REDUCE, reduce) \
|
||||
X(TRI, tri) \
|
||||
X(SSM, ssm) \
|
||||
X(WKV, wkv) \
|
||||
X(GATED_DELTA_NET, gated_delta_net)\
|
||||
X(SOLVE_TRI, solve_tri) \
|
||||
X(ROPE, rope) \
|
||||
X(CONV, conv) \
|
||||
X(UPSCALE, upscale) \
|
||||
X(ARGSORT, argsort) \
|
||||
X(POOL, pool) \
|
||||
X(MISC, misc)
|
||||
|
||||
enum ggml_metal_lib_kind {
|
||||
#define X(e, s) GGML_METAL_LIB_##e,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
GGML_METAL_LIB_COUNT,
|
||||
};
|
||||
|
||||
static const char * const k_lib_names[GGML_METAL_LIB_COUNT] = {
|
||||
#define X(e, s) [GGML_METAL_LIB_##e] = #s,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
};
|
||||
|
||||
struct ggml_metal_library {
|
||||
id<MTLLibrary> obj;
|
||||
// Per-kind compiled libraries. When single_library is true, the whole library
|
||||
// (e.g. a pre-compiled default.metallib or a from-source build) lives at
|
||||
// objs[0] and the remaining slots are nil.
|
||||
id<MTLLibrary> objs[GGML_METAL_LIB_COUNT];
|
||||
bool single_library; // true: combined library at objs[0]; false: per-kind libs in objs[*]
|
||||
|
||||
// Routing table: kernel function name -> objs[] index, populated from each
|
||||
// compiled library's -[MTLLibrary functionNames]. The actual compiled
|
||||
// libraries are the single source of truth for which library owns a kernel,
|
||||
// so adding kernels later requires no manual routing maintenance.
|
||||
// nil in single_library mode (everything resolves to objs[0]).
|
||||
NSMutableDictionary<NSString *, NSNumber *> * fn_to_lib;
|
||||
|
||||
ggml_metal_device_t dev;
|
||||
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
||||
@@ -103,160 +158,376 @@ struct ggml_metal_library {
|
||||
NSLock * lock;
|
||||
};
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
id<MTLLibrary> library = nil;
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
|
||||
// Build the fn_to_lib routing table by querying each compiled library's public
|
||||
// function names. Call once after all per-kind libraries have been compiled.
|
||||
static void ggml_metal_library_build_index(ggml_metal_library_t lib) {
|
||||
@autoreleasepool {
|
||||
NSMutableDictionary<NSString *, NSNumber *> * index = [[NSMutableDictionary alloc] init];
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
for (NSString * fname in [lib->objs[kind] functionNames]) {
|
||||
index[fname] = @(kind);
|
||||
}
|
||||
}
|
||||
lib->fn_to_lib = index;
|
||||
}
|
||||
}
|
||||
|
||||
// load library
|
||||
//
|
||||
// - first check if the library is embedded
|
||||
// - then check if the library is in the bundle
|
||||
// - if not found, load the source and compile it
|
||||
// - if that fails, return NULL
|
||||
//
|
||||
// TODO: move to a function
|
||||
{
|
||||
const int64_t t_start = ggml_time_us();
|
||||
// Parse a `#include "name"` line. Returns the quoted name in *include_name on
|
||||
// success. Whitespace-tolerant; ignores `#include <...>` (system headers).
|
||||
static bool ggml_metal_library_parse_quoted_include(NSString * line, NSString ** include_name) {
|
||||
NSScanner * scanner = [NSScanner scannerWithString:line];
|
||||
scanner.charactersToBeSkipped = [NSCharacterSet whitespaceCharacterSet];
|
||||
|
||||
NSError * error = nil;
|
||||
NSString * src = nil;
|
||||
if (![scanner scanString:@"#" intoString:NULL] ||
|
||||
![scanner scanString:@"include" intoString:NULL] ||
|
||||
![scanner scanString:@"\"" intoString:NULL]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
|
||||
NSString * name = nil;
|
||||
if (![scanner scanUpToString:@"\"" intoString:&name]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
extern const char ggml_metallib_start[];
|
||||
extern const char ggml_metallib_end[];
|
||||
if (include_name) {
|
||||
*include_name = name;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
|
||||
#else
|
||||
// Recursively inline `#include "name"` directives. System includes (<...>),
|
||||
// `#if/#else/#endif`, and other preprocessor lines are passed through to the
|
||||
// Metal compiler unchanged. `#pragma once` is dropped since `seen` already
|
||||
// guards against double-inclusion.
|
||||
static bool ggml_metal_library_flatten_file(NSMutableString * dst, NSString * path,
|
||||
NSArray<NSString *> * search_paths,
|
||||
NSMutableSet<NSString *> * seen, NSError ** error) {
|
||||
NSString * key = [path stringByStandardizingPath];
|
||||
if ([seen containsObject:key]) {
|
||||
return true;
|
||||
}
|
||||
[seen addObject:key];
|
||||
|
||||
#ifdef SWIFT_PACKAGE
|
||||
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
|
||||
#else
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
#endif
|
||||
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:error];
|
||||
if (!src) {
|
||||
return false;
|
||||
}
|
||||
|
||||
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
||||
if (path_lib == nil) {
|
||||
// Try to find the resource in the directory where the current binary located.
|
||||
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
|
||||
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
|
||||
NSFileManager * fm = [NSFileManager defaultManager];
|
||||
for (NSString * line in [src componentsSeparatedByString:@"\n"]) {
|
||||
NSString * trimmed = [line stringByTrimmingCharactersInSet:[NSCharacterSet whitespaceCharacterSet]];
|
||||
if ([trimmed isEqualToString:@"#pragma once"]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
||||
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
|
||||
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
|
||||
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
||||
// Optionally, if this is a symlink, try to resolve it.
|
||||
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
|
||||
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
|
||||
// It is a relative path, adding the binary directory as directory prefix.
|
||||
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
|
||||
}
|
||||
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
// Link to the resource could not be resolved.
|
||||
path_lib_default = nil;
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
}
|
||||
NSString * include_name = nil;
|
||||
if (ggml_metal_library_parse_quoted_include(line, &include_name)) {
|
||||
NSString * resolved = nil;
|
||||
for (NSString * dir in search_paths) {
|
||||
NSString * candidate = [dir stringByAppendingPathComponent:include_name];
|
||||
if ([fm isReadableFileAtPath:candidate]) {
|
||||
resolved = candidate;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// The resource couldn't be found in the binary's directory.
|
||||
path_lib_default = nil;
|
||||
}
|
||||
|
||||
path_lib = path_lib_default;
|
||||
if (!resolved) {
|
||||
if (error) {
|
||||
NSString * msg = [NSString stringWithFormat:@"could not resolve include \"%@\" from '%@'", include_name, path];
|
||||
*error = [NSError errorWithDomain:@"ggml-metal-source-flatten" code:1
|
||||
userInfo:@{NSLocalizedDescriptionKey: msg}];
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (!ggml_metal_library_flatten_file(dst, resolved, search_paths, seen, error)) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (path_lib != nil) {
|
||||
// pre-compiled library found
|
||||
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
||||
[dst appendString:line];
|
||||
[dst appendString:@"\n"];
|
||||
}
|
||||
|
||||
library = [device newLibraryWithURL:libURL error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return nil;
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
||||
return true;
|
||||
}
|
||||
|
||||
NSString * path_source;
|
||||
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||
static NSString * ggml_metal_library_flatten_source(NSString * path_source, NSError ** error) {
|
||||
// Search paths cover both runtime layout (build/bin/kernels + build/bin)
|
||||
// and source-tree layout (ggml/src/ggml-metal/kernels + ggml/src/ggml-metal + ggml/src).
|
||||
NSString * path_kernels = [path_source stringByDeletingLastPathComponent];
|
||||
NSString * path_base = [path_kernels stringByDeletingLastPathComponent];
|
||||
NSArray<NSString *> * search_paths = @[
|
||||
path_kernels,
|
||||
path_base,
|
||||
[path_base stringByDeletingLastPathComponent],
|
||||
];
|
||||
|
||||
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
|
||||
NSMutableString * src = [[NSMutableString alloc] init];
|
||||
NSMutableSet<NSString *> * seen = [NSMutableSet set];
|
||||
|
||||
if (path_resource) {
|
||||
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
|
||||
} else {
|
||||
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||
if (!ggml_metal_library_flatten_file(src, path_source, search_paths, seen, error)) {
|
||||
[src release];
|
||||
return nil;
|
||||
}
|
||||
return src;
|
||||
}
|
||||
|
||||
// Compile all per-kind libraries in parallel. `source_for_kind` returns the MSL
|
||||
// source for a kind (the helper takes ownership and releases it), or nil with
|
||||
// *err set on failure. On success the objs[] slots are populated and the routing
|
||||
// index is built; on any failure every error is logged and false is returned
|
||||
// (the caller is responsible for freeing `res`).
|
||||
static bool ggml_metal_library_compile_all(
|
||||
ggml_metal_library_t res,
|
||||
id<MTLDevice> device,
|
||||
NSDictionary * prep,
|
||||
NSString * (^source_for_kind)(int kind, NSError ** err),
|
||||
const char * origin) {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
int64_t * t_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(int64_t));
|
||||
NSError ** err_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(NSError *));
|
||||
__block atomic_bool any_failure = false;
|
||||
|
||||
dispatch_group_t group = dispatch_group_create();
|
||||
dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0);
|
||||
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
dispatch_group_async(group, queue, ^{
|
||||
|
||||
const int64_t t0 = ggml_time_us();
|
||||
|
||||
NSError * error = nil;
|
||||
|
||||
NSString * src = source_for_kind(kind, &error);
|
||||
if (!src) {
|
||||
err_per_lib[kind] = [error retain];
|
||||
atomic_store(&any_failure, true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (path_source == nil) {
|
||||
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
|
||||
path_source = @"ggml-metal.metal";
|
||||
}
|
||||
id<MTLLibrary> lib = nil;
|
||||
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
|
||||
|
||||
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return nil;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!library) {
|
||||
@autoreleasepool {
|
||||
// dictionary of preprocessor macros
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
|
||||
if (ggml_metal_device_get_props(dev)->has_bfloat) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
|
||||
}
|
||||
|
||||
if (ggml_metal_device_get_props(dev)->has_tensor) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
|
||||
#endif
|
||||
|
||||
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||
options.preprocessorMacros = prep;
|
||||
|
||||
//[options setFastMathEnabled:false];
|
||||
lib = [device newLibraryWithSource:src options:options error:&error];
|
||||
|
||||
library = [device newLibraryWithSource:src options:options error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return nil;
|
||||
}
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
[options release];
|
||||
#endif
|
||||
|
||||
// retain the error before the autorelease pool drains it
|
||||
if (!lib) {
|
||||
err_per_lib[kind] = [error retain];
|
||||
}
|
||||
}
|
||||
|
||||
[src release];
|
||||
|
||||
t_per_lib[kind] = ggml_time_us() - t0;
|
||||
|
||||
if (!lib) {
|
||||
atomic_store(&any_failure, true);
|
||||
return;
|
||||
}
|
||||
|
||||
res->objs[kind] = lib;
|
||||
});
|
||||
}
|
||||
dispatch_group_wait(group, DISPATCH_TIME_FOREVER);
|
||||
dispatch_release(group);
|
||||
|
||||
const bool ok = !atomic_load(&any_failure);
|
||||
|
||||
if (ok) {
|
||||
const int64_t t_total = ggml_time_us() - t_start;
|
||||
int64_t t_max = 0;
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
GGML_LOG_DEBUG("%s: compiled '%s' library in %.3f sec\n",
|
||||
__func__, k_lib_names[kind], t_per_lib[kind] / 1e6);
|
||||
if (t_per_lib[kind] > t_max) t_max = t_per_lib[kind];
|
||||
}
|
||||
GGML_LOG_INFO("%s: loaded %d libraries from %s in %.3f sec (max single = %.3f sec)\n",
|
||||
__func__, GGML_METAL_LIB_COUNT, origin, t_total / 1e6, t_max / 1e6);
|
||||
|
||||
ggml_metal_library_build_index(res);
|
||||
} else {
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
if (err_per_lib[kind]) {
|
||||
GGML_LOG_ERROR("%s: failed to build '%s' library: %s\n", __func__,
|
||||
k_lib_names[kind], [[err_per_lib[kind] description] UTF8String]);
|
||||
[err_per_lib[kind] release];
|
||||
}
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[src release];
|
||||
#endif // GGML_METAL_EMBED_LIBRARY
|
||||
|
||||
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
|
||||
}
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
free(err_per_lib);
|
||||
free(t_per_lib);
|
||||
|
||||
res->obj = library;
|
||||
return ok;
|
||||
}
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
// shared MTLCompileOptions preprocessor macros (matches the build-time defines)
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
if (ggml_metal_device_get_props(dev)->has_bfloat) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
|
||||
}
|
||||
if (ggml_metal_device_get_props(dev)->has_tensor) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
|
||||
}
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
|
||||
#endif
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
|
||||
|
||||
// start/end symbols emitted by CMake (see CMakeLists.txt), one pair per kind
|
||||
#define X(e, s) extern const char ggml_metallib_##s##_start[]; extern const char ggml_metallib_##s##_end[];
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
|
||||
static const char * const lib_start[GGML_METAL_LIB_COUNT] = {
|
||||
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_start,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
};
|
||||
static const char * const lib_end[GGML_METAL_LIB_COUNT] = {
|
||||
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_end,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
};
|
||||
|
||||
const bool ok = ggml_metal_library_compile_all(res, device, prep,
|
||||
^NSString * (int kind, NSError ** err) {
|
||||
(void) err;
|
||||
return [[NSString alloc] initWithBytes:lib_start[kind]
|
||||
length:(lib_end[kind] - lib_start[kind])
|
||||
encoding:NSUTF8StringEncoding];
|
||||
}, "embedded data");
|
||||
|
||||
if (!ok) {
|
||||
ggml_metal_library_free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return res;
|
||||
#else
|
||||
#ifdef SWIFT_PACKAGE
|
||||
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
|
||||
#else
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
#endif
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
NSError * error = nil;
|
||||
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
||||
if (path_lib == nil) {
|
||||
// Try to find the resource in the directory where the current binary located.
|
||||
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
|
||||
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
|
||||
|
||||
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
||||
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
|
||||
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
|
||||
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
||||
// Optionally, if this is a symlink, try to resolve it.
|
||||
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
|
||||
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
|
||||
// It is a relative path, adding the binary directory as directory prefix.
|
||||
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
|
||||
}
|
||||
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
// Link to the resource could not be resolved.
|
||||
path_lib_default = nil;
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// The resource couldn't be found in the binary's directory.
|
||||
path_lib_default = nil;
|
||||
}
|
||||
|
||||
path_lib = path_lib_default;
|
||||
}
|
||||
|
||||
if (path_lib != nil) {
|
||||
// pre-compiled library found: a single combined default.metallib
|
||||
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
||||
|
||||
res->objs[0] = [device newLibraryWithURL:libURL error:&error];
|
||||
res->single_library = true;
|
||||
if (!res->objs[0]) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
ggml_metal_library_free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
|
||||
return res;
|
||||
}
|
||||
|
||||
// no pre-compiled metallib: fall back to compiling each kernel source separately
|
||||
GGML_LOG_INFO("%s: default.metallib not found, loading kernel sources\n", __func__);
|
||||
|
||||
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||
if (path_resource) {
|
||||
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, [path_resource UTF8String]);
|
||||
}
|
||||
|
||||
// resolve each kind's source path up front (file lookup/logging stays on the calling thread)
|
||||
NSString ** path_per_kind = calloc(GGML_METAL_LIB_COUNT, sizeof(NSString *));
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
NSString * rel = [NSString stringWithFormat:@"kernels/%s.metal", k_lib_names[kind]];
|
||||
|
||||
NSString * path_source = nil;
|
||||
if (path_resource) {
|
||||
path_source = [path_resource stringByAppendingPathComponent:rel];
|
||||
} else {
|
||||
NSString * stem = [NSString stringWithFormat:@"kernels/%s", k_lib_names[kind]];
|
||||
path_source = [bundle pathForResource:stem ofType:@"metal"];
|
||||
}
|
||||
|
||||
if (path_source == nil || ![[NSFileManager defaultManager] isReadableFileAtPath:path_source]) {
|
||||
GGML_LOG_WARN("%s: could not locate %s in bundle, falling back to cwd\n", __func__, [rel UTF8String]);
|
||||
path_source = rel;
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("%s: loading '%s'\n", __func__, [path_source UTF8String]);
|
||||
|
||||
path_per_kind[kind] = [path_source retain];
|
||||
}
|
||||
|
||||
const bool ok = ggml_metal_library_compile_all(res, device, prep,
|
||||
^NSString * (int kind, NSError ** err) {
|
||||
return ggml_metal_library_flatten_source(path_per_kind[kind], err);
|
||||
}, "source");
|
||||
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
[path_per_kind[kind] release];
|
||||
}
|
||||
free(path_per_kind);
|
||||
|
||||
if (!ok) {
|
||||
ggml_metal_library_free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
|
||||
@@ -318,10 +589,11 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
return NULL;
|
||||
}
|
||||
|
||||
res->obj = library;
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
res->objs[0] = library;
|
||||
res->single_library = true;
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -331,8 +603,14 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (lib->obj) {
|
||||
[lib->obj release];
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
if (lib->objs[kind]) {
|
||||
[lib->objs[kind] release];
|
||||
}
|
||||
}
|
||||
|
||||
if (lib->fn_to_lib) {
|
||||
[lib->fn_to_lib release];
|
||||
}
|
||||
|
||||
ggml_metal_pipelines_free(lib->pipelines);
|
||||
@@ -393,11 +671,28 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
|
||||
|
||||
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
|
||||
// route to the library that actually defines this kernel; fn_to_lib is
|
||||
// built from -[MTLLibrary functionNames] so it's always in sync
|
||||
int lib_idx = 0;
|
||||
if (!lib->single_library) {
|
||||
NSNumber * idx = lib->fn_to_lib[base_func];
|
||||
if (!idx) {
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: kernel not found in any metal library: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
|
||||
return res;
|
||||
}
|
||||
lib_idx = [idx intValue];
|
||||
}
|
||||
|
||||
id<MTLLibrary> mtl_lib = lib->objs[lib_idx];
|
||||
|
||||
id<MTLFunction> mtl_function;
|
||||
if (!cv) {
|
||||
mtl_function = [lib->obj newFunctionWithName:base_func];
|
||||
mtl_function = [mtl_lib newFunctionWithName:base_func];
|
||||
} else {
|
||||
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||
mtl_function = [mtl_lib newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||
}
|
||||
if (!mtl_function) {
|
||||
[lib->lock unlock];
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,232 @@
|
||||
#include "common.h"
|
||||
|
||||
// bitonic sort implementation following the CUDA kernels as reference
|
||||
typedef void (argsort_t)(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_f32_i32(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// bitonic sort
|
||||
const int col = tpitg[0];
|
||||
const int ib = tgpig[0] / args.ne01;
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0] % args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||
|
||||
// initialize indices
|
||||
shmem_i32[col] = i00 + col;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int k = 2; k <= ntg.x; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (shmem_i32[col] >= args.ne00 ||
|
||||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (shmem_i32[ixj] >= args.ne00 ||
|
||||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t i0 = ib*args.top_k;
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (i0 + col < args.ne0 && col < args.top_k) {
|
||||
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
|
||||
|
||||
dst[col] = shmem_i32[col];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
typedef void (argsort_merge_t)(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_merge_f32_i32(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int im = tgpig[0] / args.ne01;
|
||||
const int i01 = tgpig[0] % args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2 * args.len);
|
||||
|
||||
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
|
||||
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
|
||||
|
||||
const int total = len0 + len1;
|
||||
|
||||
device const int32_t * tmp0 = tmp + start
|
||||
+ i01*args.ne0
|
||||
+ i02*args.ne0*args.ne01
|
||||
+ i03*args.ne0*args.ne01*args.ne02;
|
||||
|
||||
device const int32_t * tmp1 = tmp0 + args.len;
|
||||
|
||||
dst += start
|
||||
+ i01*args.top_k
|
||||
+ i02*args.top_k*args.ne01
|
||||
+ i03*args.top_k*args.ne01*args.ne02;
|
||||
|
||||
device const float * src0_row = (device const float *)(src0
|
||||
+ args.nb01*i01
|
||||
+ args.nb02*i02
|
||||
+ args.nb03*i03);
|
||||
|
||||
if (total == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||
|
||||
const int k0 = tpitg.x * chunk;
|
||||
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
|
||||
|
||||
if (k0 >= args.top_k) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (k0 >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
int low = k0 > len1 ? k0 - len1 : 0;
|
||||
int high = MIN(k0, len0);
|
||||
|
||||
// binary-search partition (i, j) such that i + j = k
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
|
||||
bool take_left;
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
if (take_left) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
|
||||
int i = low;
|
||||
int j = k0 - i;
|
||||
|
||||
// keep the merge fronts into registers
|
||||
int32_t idx0 = 0;
|
||||
float val0 = 0.0f;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
|
||||
int32_t idx1 = 0;
|
||||
float val1 = 0.0f;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
|
||||
for (int k = k0; k < k1; ++k) {
|
||||
int32_t out_idx;
|
||||
|
||||
if (i >= len0) {
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp1[j++];
|
||||
}
|
||||
break;
|
||||
} else if (j >= len1) {
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp0[i++];
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
bool take_left;
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
if (take_left) {
|
||||
out_idx = idx0;
|
||||
++i;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
} else {
|
||||
out_idx = idx1;
|
||||
++j;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst[k] = out_idx;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
@@ -0,0 +1,226 @@
|
||||
#include "common.h"
|
||||
|
||||
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
||||
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
||||
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
||||
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
|
||||
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
|
||||
|
||||
template <typename T0, typename T1, typename T>
|
||||
kernel void kernel_bin_fuse_impl(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_bin_op
|
||||
#define FC_F FC_bin_f
|
||||
#define FC_RB FC_bin_rb
|
||||
#define FC_CB FC_bin_cb
|
||||
|
||||
if (FC_RB) {
|
||||
// row broadcast
|
||||
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
|
||||
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
|
||||
|
||||
device const T0 * src0_row = (device const T0 *) (src0);
|
||||
device T * dst_row = (device T *) (dst);
|
||||
|
||||
if (FC_F == 1) {
|
||||
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
|
||||
|
||||
if (FC_OP == 0) {
|
||||
dst_row[i0] = src0_row[i0] + src1_row[i1];
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
dst_row[i0] = src0_row[i0] - src1_row[i1];
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
dst_row[i0] = src0_row[i0] * src1_row[i1];
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
dst_row[i0] = src0_row[i0] / src1_row[i1];
|
||||
}
|
||||
} else {
|
||||
T0 res = src0_row[i0];
|
||||
|
||||
if (FC_OP == 0) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
dst_row[i0] = res;
|
||||
}
|
||||
} else {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
||||
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
||||
|
||||
if (FC_F == 1) {
|
||||
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
||||
|
||||
if (FC_OP == 0) {
|
||||
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const T1 * src1_ptr[8];
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
}
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
||||
|
||||
T res = src0_ptr[i0];
|
||||
|
||||
if (FC_OP == 0) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res += src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res -= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res *= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res /= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
#undef FC_F
|
||||
#undef FC_RB
|
||||
#undef FC_CB
|
||||
}
|
||||
|
||||
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
|
||||
|
||||
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
|
||||
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
|
||||
|
||||
kernel void kernel_add_id(
|
||||
constant ggml_metal_kargs_add_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i1 = tgpig.x;
|
||||
const int i2 = tgpig.y;
|
||||
|
||||
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
|
||||
|
||||
const size_t nb1 = args.ne0 * sizeof(float);
|
||||
const size_t nb2 = args.ne1 * nb1;
|
||||
|
||||
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
||||
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
|
||||
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
dst_row[i0] = src0_row[i0] + src1_row[i0];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_repeat(
|
||||
constant ggml_metal_kargs_repeat & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
const int i03 = i3%args.ne03;
|
||||
const int i02 = i2%args.ne02;
|
||||
const int i01 = i1%args.ne01;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
||||
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i00 = i0%args.ne00;
|
||||
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
||||
|
||||
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
||||
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_repeat_bf16")]] kernel kernel_repeat_t kernel_repeat<bfloat>;
|
||||
#endif
|
||||
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
||||
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
||||
@@ -0,0 +1,126 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-metal-impl.h"
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
#include <metal_tensor>
|
||||
|
||||
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
|
||||
#endif
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
||||
|
||||
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
|
||||
|
||||
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
|
||||
|
||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
//
|
||||
// cmd:
|
||||
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/kernels/<src>.metal
|
||||
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/kernels/<src>.metal
|
||||
//
|
||||
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
|
||||
#undef GGML_METAL_HAS_BF16
|
||||
#endif
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
||||
typedef matrix<bfloat, 2, 4> bfloat2x4;
|
||||
#endif
|
||||
|
||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
constexpr constant static float kvalues_mxfp4_f[16] = {
|
||||
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
|
||||
};
|
||||
|
||||
static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
if (x >= val[n-1]) return n-1;
|
||||
int ml = 0, mu = n-1;
|
||||
while (mu-ml > 1) {
|
||||
int mav = (ml+mu)/2;
|
||||
if (x < val[mav]) mu = mav; else ml = mav;
|
||||
}
|
||||
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
||||
}
|
||||
|
||||
static inline float e8m0_to_fp32(uint8_t x) {
|
||||
uint32_t bits;
|
||||
|
||||
if (x == 0) {
|
||||
bits = 0x00400000;
|
||||
} else {
|
||||
bits = (uint32_t) x << 23;
|
||||
}
|
||||
|
||||
return as_type<float>(bits);
|
||||
}
|
||||
|
||||
static inline float dot(float x, float y) {
|
||||
return x*y;
|
||||
}
|
||||
|
||||
static inline float sum(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline float sum(float4 x) {
|
||||
return x[0] + x[1] + x[2] + x[3];
|
||||
}
|
||||
|
||||
enum ggml_sort_order {
|
||||
GGML_SORT_ORDER_ASC,
|
||||
GGML_SORT_ORDER_DESC,
|
||||
};
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
constant float p_erf = 0.3275911f;
|
||||
constant float a1_erf = 0.254829592f;
|
||||
constant float a2_erf = -0.284496736f;
|
||||
constant float a3_erf = 1.421413741f;
|
||||
constant float a4_erf = -1.453152027f;
|
||||
constant float a5_erf = 1.061405429f;
|
||||
|
||||
template<typename T>
|
||||
inline T erf_approx(T x) {
|
||||
T sign_x = sign(x);
|
||||
x = fabs(x);
|
||||
T t = 1.0f / (1.0f + p_erf * x);
|
||||
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
return sign_x * y;
|
||||
}
|
||||
|
||||
template<typename T> T elu_approx(T x);
|
||||
|
||||
template<> inline float elu_approx<float>(float x) {
|
||||
return (x > 0.f) ? x : (exp(x) - 1);
|
||||
}
|
||||
|
||||
template<> inline float4 elu_approx<float4>(float4 x) {
|
||||
float4 res;
|
||||
|
||||
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
||||
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
||||
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
||||
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
#include "common.h"
|
||||
|
||||
typedef void (im2col_t)(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_im2col(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
// const int64_t IC = tgpg[0];
|
||||
const int64_t OH = tgpg[1];
|
||||
const int64_t OW = tgpg[2];
|
||||
|
||||
const int64_t KH = ntg[1];
|
||||
const int64_t KW = ntg[2];
|
||||
|
||||
int64_t in = tpitg[0];
|
||||
const int64_t ikh = tpitg[1];
|
||||
const int64_t ikw = tpitg[2];
|
||||
|
||||
const int64_t iic = tgpig[0];
|
||||
const int64_t ioh = tgpig[1];
|
||||
const int64_t iow = tgpig[2];
|
||||
|
||||
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
|
||||
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
|
||||
|
||||
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
||||
|
||||
device T * pdst = (device T *) (dst);
|
||||
|
||||
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
||||
while (in < args.N) {
|
||||
pdst[offset_dst] = 0.0f;
|
||||
offset_dst += ntg[0]*args.CHW*OH*OW;
|
||||
|
||||
in += ntg[0];
|
||||
}
|
||||
} else {
|
||||
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
|
||||
|
||||
while (in < args.N) {
|
||||
pdst[offset_dst] = x[offset_src];
|
||||
|
||||
offset_dst += ntg[0]*args.CHW*OH*OW;
|
||||
offset_src += ntg[0]*args.ofs0;
|
||||
|
||||
in += ntg[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
||||
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
||||
|
||||
// TODO: optimize
|
||||
typedef void (im2col_ext_t)(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_im2col_ext(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
||||
const int64_t KHW = (int64_t)args.KHW;
|
||||
|
||||
const int64_t d = tgpig[0] / args.CHW;
|
||||
const int64_t chw = tgpig[0] % args.CHW;
|
||||
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
||||
const int64_t HW = tgpig[0] % KHW;
|
||||
|
||||
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
||||
if (tpitg_0 >= args.N) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t tpitg_1 = HW / args.KW;
|
||||
const int64_t tpitg_2 = HW % args.KW;
|
||||
|
||||
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
||||
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
||||
|
||||
const int64_t offset_dst =
|
||||
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
||||
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
||||
|
||||
device T * pdst = (device T *) (dst);
|
||||
|
||||
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
||||
pdst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
||||
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
||||
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
||||
|
||||
template <typename TK>
|
||||
kernel void kernel_conv_2d(
|
||||
constant ggml_metal_kargs_conv_2d & args,
|
||||
device const char * weights,
|
||||
device const char * src,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
|
||||
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
|
||||
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
|
||||
const uint thread_index = tg_index * threads_per_tg + local_thread;
|
||||
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
|
||||
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
|
||||
|
||||
for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
|
||||
uint64_t tmp = index;
|
||||
|
||||
const int32_t ow = tmp % args.OW; tmp /= args.OW;
|
||||
const int32_t oh = tmp % args.OH; tmp /= args.OH;
|
||||
const int32_t oc = tmp % args.OC; tmp /= args.OC;
|
||||
const int32_t n = tmp;
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
const int32_t base_x = ow*args.s0 - args.p0;
|
||||
const int32_t base_y = oh*args.s1 - args.p1;
|
||||
|
||||
int32_t ky_start = 0;
|
||||
if (base_y < 0) {
|
||||
ky_start = (-base_y + args.d1 - 1)/args.d1;
|
||||
}
|
||||
int32_t ky_end = args.KH;
|
||||
const int32_t y_max = args.IH - 1 - base_y;
|
||||
if (y_max < 0) {
|
||||
ky_end = ky_start;
|
||||
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
|
||||
ky_end = min(ky_end, y_max/args.d1 + 1);
|
||||
}
|
||||
|
||||
int32_t kx_start = 0;
|
||||
if (base_x < 0) {
|
||||
kx_start = (-base_x + args.d0 - 1)/args.d0;
|
||||
}
|
||||
int32_t kx_end = args.KW;
|
||||
const int32_t x_max = args.IW - 1 - base_x;
|
||||
if (x_max < 0) {
|
||||
kx_end = kx_start;
|
||||
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
|
||||
kx_end = min(kx_end, x_max/args.d0 + 1);
|
||||
}
|
||||
|
||||
if (ky_start < ky_end && kx_start < kx_end) {
|
||||
const uint64_t src_base_n = (uint64_t) n * args.nb13;
|
||||
const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
|
||||
|
||||
for (int32_t ic = 0; ic < args.IC; ++ic) {
|
||||
const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
|
||||
const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
|
||||
|
||||
for (int32_t ky = ky_start; ky < ky_end; ++ky) {
|
||||
const int32_t iy = base_y + ky*args.d1;
|
||||
const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
|
||||
const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
|
||||
|
||||
for (int32_t kx = kx_start; kx < kx_end; ++kx) {
|
||||
const int32_t ix = base_x + kx*args.d0;
|
||||
const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
|
||||
const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
|
||||
|
||||
const float x = *(device const float *)(src + src_offs);
|
||||
const float w = (float) (*(device const TK *)(weights + w_offs));
|
||||
|
||||
acc += x * w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const uint64_t dst_offs =
|
||||
(uint64_t) n * args.nb3 +
|
||||
(uint64_t) oc * args.nb2 +
|
||||
(uint64_t) oh * args.nb1 +
|
||||
(uint64_t) ow * args.nb0;
|
||||
|
||||
*(device float *)(dst + dst_offs) = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_conv_2d_f32_f32")]]
|
||||
kernel void kernel_conv_2d<float>(
|
||||
constant ggml_metal_kargs_conv_2d & args,
|
||||
device const char * weights,
|
||||
device const char * src,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template [[host_name("kernel_conv_2d_f16_f32")]]
|
||||
kernel void kernel_conv_2d<half>(
|
||||
constant ggml_metal_kargs_conv_2d & args,
|
||||
device const char * weights,
|
||||
device const char * src,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
typedef void (conv_transpose_1d_t)(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_conv_transpose_1d(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const T * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||
|
||||
// For output position j on the time axis, only input positions
|
||||
// i such that i*s0 <= j < i*s0 + K
|
||||
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
|
||||
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
|
||||
// (typically 2 for stride==K/2 transposed convs).
|
||||
const int32_t j = tgpig[0];
|
||||
const int32_t s0 = args.s0;
|
||||
const int32_t K = args.K;
|
||||
const int32_t IL = args.IL;
|
||||
|
||||
int32_t i_min;
|
||||
{
|
||||
int32_t a = j - K + 1;
|
||||
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
|
||||
}
|
||||
int32_t i_max = j / s0;
|
||||
if (i_max > IL - 1) i_max = IL - 1;
|
||||
|
||||
float v = 0.0f;
|
||||
if (i_min <= i_max) {
|
||||
for (int64_t c = 0; c < args.IC; c++) {
|
||||
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
|
||||
const int32_t input_offset = c * IL;
|
||||
|
||||
for (int32_t i = i_min; i <= i_max; i++) {
|
||||
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
|
||||
|
||||
dst_ptr[0] = v;
|
||||
}
|
||||
|
||||
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
|
||||
kernel void kernel_conv_transpose_1d<float>(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
|
||||
kernel void kernel_conv_transpose_1d<half>(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const half * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
|
||||
typedef void (conv_transpose_2d_t)(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_conv_transpose_2d(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const T * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t out_x = tgpig[0];
|
||||
const int64_t out_y = tgpig[1];
|
||||
const int64_t out_c = tgpig[2];
|
||||
|
||||
const int64_t kw = tpitg[0];
|
||||
const int64_t kh = tpitg[1];
|
||||
|
||||
float v = 0.0f;
|
||||
|
||||
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
|
||||
int64_t in_y = out_y - kh;
|
||||
|
||||
if (in_y < 0 || in_y % args.s0) continue;
|
||||
|
||||
in_y /= args.s0;
|
||||
|
||||
if (in_y >= args.IH) continue;
|
||||
|
||||
int64_t in_x = out_x - kw;
|
||||
|
||||
if (in_x < 0 || in_x % args.s0) continue;
|
||||
|
||||
in_x /= args.s0;
|
||||
|
||||
if (in_x >= args.IW) continue;
|
||||
|
||||
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
|
||||
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
|
||||
|
||||
v += (float)src0[kernel_idx] * src1[input_idx];
|
||||
}
|
||||
|
||||
const uint tid = tpitg.y * ntg.x + tpitg.x;
|
||||
shared_sum[tid] = v;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tid == 0) {
|
||||
float total = 0.0f;
|
||||
const uint num_threads = ntg.x * ntg.y;
|
||||
for (uint i = 0; i < num_threads; i++) {
|
||||
total += shared_sum[i];
|
||||
}
|
||||
|
||||
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
|
||||
dst_ptr[0] = total;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
|
||||
kernel void kernel_conv_transpose_2d<float>(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
|
||||
kernel void kernel_conv_transpose_2d<half>(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const half * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_conv_3d(
|
||||
constant ggml_metal_kargs_conv_3d & args,
|
||||
device const char * src0, // Weights [IC * OC, KD, KH, KW]
|
||||
device const char * src1, // Inputs [IC * N, ID, IH, IW]
|
||||
device char * dst, // Outputs [OC * N, OD, OH, OW]
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
||||
|
||||
// 1. Un-flatten the spatial dimension from Grid X
|
||||
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
|
||||
|
||||
if (spatial_idx >= args.OW * args.OH * args.OD) {
|
||||
return; // Thread falls outside the spatial volume
|
||||
}
|
||||
|
||||
int64_t od = spatial_idx / (args.OW * args.OH);
|
||||
int64_t oh = (spatial_idx / args.OW) % args.OH;
|
||||
int64_t ow = spatial_idx % args.OW;
|
||||
|
||||
// 2. Map Y to Channels, Z to Batch
|
||||
int64_t oc = tgpig.y;
|
||||
int64_t batch_idx = tgpig.z;
|
||||
|
||||
// 3. Calculate anchor coordinates in the Input volume
|
||||
int64_t i_w_base = ow * args.s0 - args.p0;
|
||||
int64_t i_h_base = oh * args.s1 - args.p1;
|
||||
int64_t i_d_base = od * args.s2 - args.p2;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
|
||||
for (int64_t ic = 0; ic < args.IC; ++ic) {
|
||||
|
||||
// ggml packs batch and channel together in the 4th dimension
|
||||
int64_t src_cn_idx = batch_idx * args.IC + ic;
|
||||
int64_t w_cn_idx = oc * args.IC + ic;
|
||||
|
||||
for (int64_t kz = 0; kz < args.KD; ++kz) {
|
||||
int64_t id = i_d_base + kz * args.d2;
|
||||
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
|
||||
|
||||
for (int64_t ky = 0; ky < args.KH; ++ky) {
|
||||
int64_t ih = i_h_base + ky * args.d1;
|
||||
if (ih < 0 || ih >= args.IH) continue;
|
||||
|
||||
for (int64_t kx = 0; kx < args.KW; ++kx) {
|
||||
int64_t iw = i_w_base + kx * args.d0;
|
||||
if (iw < 0 || iw >= args.IW) continue;
|
||||
|
||||
// Convert multi-dimensional coordinates to flat byte offsets
|
||||
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
|
||||
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
|
||||
|
||||
// Dereference memory and cast weights to f32 if they were f16
|
||||
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
|
||||
float i_val = *(device const float*)((device const char*)src1 + i_idx);
|
||||
|
||||
sum += w_val * i_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Write the accumulated value out to RAM
|
||||
int64_t dst_cn_idx = batch_idx * args.OC + oc;
|
||||
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
|
||||
|
||||
*(device float*)(dst + d_idx) = sum;
|
||||
}
|
||||
|
||||
// Explicit instantiations so the JIT compiler can find them by name
|
||||
template [[host_name("kernel_conv_3d_f32_f32")]]
|
||||
kernel void kernel_conv_3d<float>(
|
||||
constant ggml_metal_kargs_conv_3d & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]);
|
||||
|
||||
// Explicit instantiation for f16 weights
|
||||
template [[host_name("kernel_conv_3d_f16_f32")]]
|
||||
kernel void kernel_conv_3d<half>(
|
||||
constant ggml_metal_kargs_conv_3d & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]);
|
||||
@@ -0,0 +1,686 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#define GGML_COMMON_DECL_METAL
|
||||
#define GGML_COMMON_IMPL_METAL
|
||||
#if defined(GGML_METAL_EMBED_LIBRARY)
|
||||
__embed_ggml-common.h__
|
||||
#else
|
||||
#include "ggml-common.h"
|
||||
#endif
|
||||
|
||||
#define QK_NL 16 // shared by mul_mm and get_rows_q instantiations
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*(src));
|
||||
}
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template <typename type4x4>
|
||||
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*(src));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * qs = xb->qs;
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
|
||||
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
||||
const uint8_t b0 = qs[byte_offset];
|
||||
const uint8_t b1 = qs[byte_offset + 1];
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
||||
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
||||
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
||||
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
||||
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
||||
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
||||
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
||||
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
||||
|
||||
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
||||
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
||||
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
||||
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
||||
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
||||
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
||||
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
||||
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
const int base = il * 4;
|
||||
const uint8_t byte = xb->qs[base / 8];
|
||||
const int s = base % 8;
|
||||
|
||||
float4 reg_f;
|
||||
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
||||
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
||||
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
||||
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
||||
|
||||
reg = (type4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float md = -8.h * xb->d;
|
||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
|
||||
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float md = -8.h * xb->d;
|
||||
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
|
||||
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float m = xb->m;
|
||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
|
||||
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float m = xb->m;
|
||||
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
|
||||
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||
const float d = xb->d;
|
||||
const float md = -16.h * xb->d;
|
||||
const ushort mask = il ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = il ? 4 : 0;
|
||||
|
||||
const int gh_mv = il ? 12 : 0;
|
||||
const int gh_bk = il ? 0 : 4;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
|
||||
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||
const float d = xb->d;
|
||||
const float md = -16.h * xb->d;
|
||||
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = (il/4) ? 4 : 0;
|
||||
|
||||
const int gh_mv = (il/4) ? 12 : 0;
|
||||
const int gh_bk = (il/4) ? 0 : 4;
|
||||
|
||||
for (int ii = 0; ii < 2; ii++) {
|
||||
int i = 2*(il%4) + ii;
|
||||
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg[2*ii + 0] = d * x0 + md;
|
||||
reg[2*ii + 1] = d * x1 + md;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||
const float d = xb->d;
|
||||
const float m = xb->m;
|
||||
const ushort mask = il ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = il ? 4 : 0;
|
||||
|
||||
const int gh_mv = il ? 12 : 0;
|
||||
const int gh_bk = il ? 0 : 4;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
|
||||
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||
const float d = xb->d;
|
||||
const float m = xb->m;
|
||||
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = (il/4) ? 4 : 0;
|
||||
|
||||
const int gh_mv = (il/4) ? 12 : 0;
|
||||
const int gh_bk = (il/4) ? 0 : 4;
|
||||
|
||||
for (int ii = 0; ii < 2; ii++) {
|
||||
int i = 2*(il%4) + ii;
|
||||
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg[2*ii + 0] = d * x0 + m;
|
||||
reg[2*ii + 1] = d * x1 + m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const float d = xb->d;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const float d = xb->d;
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
|
||||
|
||||
const float d = e8m0_to_fp32(xb->e);
|
||||
const uint8_t shr = il >= 1 ? 4 : 0;
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
|
||||
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
|
||||
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
|
||||
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
|
||||
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
|
||||
|
||||
const float d = e8m0_to_fp32(xb->e);
|
||||
const short il4 = il%4;
|
||||
|
||||
const uint8_t shr = il >= 4 ? 4 : 0;
|
||||
|
||||
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
|
||||
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
|
||||
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
|
||||
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float min = xb->dmin;
|
||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||
float dl, ml;
|
||||
uint8_t sc = xb->scales[il];
|
||||
|
||||
q = q + 32*(il/8) + 16*(il&1);
|
||||
il = (il/2)%4;
|
||||
|
||||
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
||||
const half d_all = xb->d;
|
||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||
|
||||
q = q + 32 * (il/8) + 16 * (il&1);
|
||||
h = h + 16 * (il&1);
|
||||
uint8_t m = 1 << (il/2);
|
||||
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
|
||||
((il/4)>0 ? 12 : 3);
|
||||
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
||||
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
||||
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
||||
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
||||
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
||||
const float ml = 4.f * dl;
|
||||
|
||||
il = (il/2) & 3;
|
||||
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
dl *= coef;
|
||||
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
||||
}
|
||||
}
|
||||
|
||||
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
||||
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
|
||||
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
|
||||
device const uchar * q = xb->qs;
|
||||
|
||||
short is = (il/4) * 2;
|
||||
q = q + (il/4) * 32 + 16 * (il&1);
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const float min = xb->dmin;
|
||||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
|
||||
const ushort mask = il < 2 ? 0x0F : 0xF0;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * q = xb->qs;
|
||||
device const uint8_t * qh = xb->qh;
|
||||
|
||||
short is = (il/4) * 2;
|
||||
q = q + 32 * (il/4) + 16 * (il&1);
|
||||
qh = qh + 16 * (il&1);
|
||||
uint8_t ul = 1 << (il/2);
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
||||
const float min = xb->dmin;
|
||||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
|
||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||
const float qh_val = il<2 ? 16.f : 256.f;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
||||
const half d_all = xb->d;
|
||||
device const uint16_t * ql = (device const uint16_t *)xb->ql;
|
||||
device const uint16_t * qh = (device const uint16_t *)xb->qh;
|
||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||
|
||||
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
|
||||
qh = qh + 16*(il/8) + 8*(il&1);
|
||||
float sc = scales[(il%2) + 2 * ((il/2))];
|
||||
il = (il/2) & 3;
|
||||
|
||||
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
|
||||
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
|
||||
const float ml = d_all * sc * 32.f;
|
||||
const float dl0 = d_all * sc;
|
||||
const float dl1 = dl0 / 256.f;
|
||||
const float dl2 = dl0 / (256.f * 256.f);
|
||||
const float dl3 = dl0 / (256.f * 256.f * 256.f);
|
||||
const uint8_t shr_h = il>2 ? 2 : 0;
|
||||
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
|
||||
const uint8_t shr_l = il>1 ? 4 : 0;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
|
||||
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
|
||||
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
|
||||
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
|
||||
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
|
||||
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
|
||||
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
|
||||
device const uint16_t * q2 = xb->qs + 4*ib32;
|
||||
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
|
||||
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
||||
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
||||
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
||||
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
||||
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint16_t * q2 = xb->qs + 4*ib32;
|
||||
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
||||
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
|
||||
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
|
||||
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint8_t * q3 = xb->qs + 8*ib32;
|
||||
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
|
||||
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
||||
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
|
||||
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||
}
|
||||
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
|
||||
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
|
||||
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint8_t * qs = xb->qs + 8*ib32;
|
||||
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
||||
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
||||
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
||||
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
||||
}
|
||||
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
||||
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
||||
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint8_t * signs = qs + QK_K/8;
|
||||
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
||||
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
|
||||
reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
const float d = xb->d;
|
||||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint16_t * qh = xb->qh;
|
||||
const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
|
||||
const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
|
||||
const uint16_t h = qh[ib32] >> 6*il;
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * (grid1[i] & 0xf) + ml;
|
||||
reg[1][i] = dl * (grid1[i] >> 4) + ml;
|
||||
reg[2][i] = dl * (grid2[i] & 0xf) + ml;
|
||||
reg[3][i] = dl * (grid2[i] >> 4) + ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
||||
|
||||
iq1m_scale_t scale;
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
const float d = scale.f16;
|
||||
|
||||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
||||
|
||||
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
||||
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
|
||||
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
|
||||
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
|
||||
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
||||
const float d = xb->d;
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
|
||||
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
||||
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
||||
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
||||
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
||||
const float d = xb->d;
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
|
||||
reg[0] = d * kvalues_iq4nl_f[q8[0]];
|
||||
reg[1] = d * kvalues_iq4nl_f[q8[1]];
|
||||
reg[2] = d * kvalues_iq4nl_f[q8[2]];
|
||||
reg[3] = d * kvalues_iq4nl_f[q8[3]];
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
|
||||
const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
|
||||
const float d = (float)xb->d * (ls - 32);
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
|
||||
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
||||
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
||||
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
||||
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,250 @@
|
||||
#include "common.h"
|
||||
|
||||
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
||||
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
||||
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
|
||||
|
||||
#if 1
|
||||
template<short NSG>
|
||||
kernel void kernel_gated_delta_net_impl(
|
||||
constant ggml_metal_kargs_gated_delta_net & args,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * g,
|
||||
device const char * b,
|
||||
device const char * s,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
#define S_v FC_gated_delta_net_ne20
|
||||
#define G FC_gated_delta_net_ne30
|
||||
#define K FC_gated_delta_net_K
|
||||
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B (n_seqs)
|
||||
const uint i21 = tgpig.y; // H (head)
|
||||
const uint i20 = tgpig.x*NSG + ty; // row within S_v
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
|
||||
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
||||
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
device const float * s_ptr = (device const float *) (s) + state_in_base;
|
||||
|
||||
float ls[NSG];
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] = s_ptr[is];
|
||||
}
|
||||
|
||||
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
||||
|
||||
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
||||
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
||||
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
||||
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
|
||||
// output state base offset: after attention scores
|
||||
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
|
||||
// output state per-slot size: S_v * S_v * H * n_seqs
|
||||
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
|
||||
// per-(seq,head) offset within a slot
|
||||
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
float s_k = 0.0f;
|
||||
|
||||
if (G == 1) {
|
||||
const float g_exp = exp(g_ptr[0]);
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] *= g_exp;
|
||||
|
||||
s_k += ls[j]*k_ptr[is];
|
||||
}
|
||||
} else {
|
||||
// KDA
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] *= exp(g_ptr[is]);
|
||||
|
||||
s_k += ls[j]*k_ptr[is];
|
||||
}
|
||||
}
|
||||
|
||||
s_k = simd_sum(s_k);
|
||||
|
||||
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
||||
|
||||
float y = 0.0f;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] += k_ptr[is]*d;
|
||||
|
||||
y += ls[j]*q_ptr[is];
|
||||
}
|
||||
|
||||
y = simd_sum(y);
|
||||
|
||||
if (tx == 0) {
|
||||
dst_attn[t*args.ne21*S_v] = y*scale;
|
||||
}
|
||||
|
||||
q_ptr += args.ns02;
|
||||
k_ptr += args.ns12;
|
||||
v_ptr += args.ns22;
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
if (K > 1) {
|
||||
const int target_slot = (int)args.ne22 - 1 - (int)t;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (K == 1) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
#undef G
|
||||
#undef K
|
||||
}
|
||||
|
||||
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
|
||||
|
||||
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
|
||||
|
||||
#else
|
||||
// a simplified version of the above
|
||||
// no performance improvement, so keep the above version for now
|
||||
|
||||
template<typename T, short NSG>
|
||||
kernel void kernel_gated_delta_net_impl(
|
||||
constant ggml_metal_kargs_gated_delta_net & args,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * g,
|
||||
device const char * b,
|
||||
device const char * s,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
#define S_v FC_gated_delta_net_ne20
|
||||
#define G FC_gated_delta_net_ne30
|
||||
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B
|
||||
const uint i21 = tgpig.y; // H
|
||||
const uint i20 = tgpig.x*NSG + ty;
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
|
||||
float lsf[NSG];
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
lsf[j] = s_ptr[is*S_v];
|
||||
}
|
||||
|
||||
thread T * ls = (thread T *) (lsf);
|
||||
|
||||
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
||||
|
||||
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
||||
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
||||
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
||||
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
device const T * qt_ptr = (device const T *) (q_ptr);
|
||||
device const T * kt_ptr = (device const T *) (k_ptr);
|
||||
device const T * gt_ptr = (device const T *) (g_ptr);
|
||||
|
||||
if (G == 1) {
|
||||
*ls *= exp(g_ptr[0]);
|
||||
} else {
|
||||
// KDA
|
||||
*ls *= exp(gt_ptr[tx]);
|
||||
}
|
||||
|
||||
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
|
||||
|
||||
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
||||
|
||||
*ls += kt_ptr[tx]*d;
|
||||
|
||||
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
|
||||
|
||||
if (tx == 0) {
|
||||
*dst_attn = y*scale;
|
||||
}
|
||||
|
||||
q_ptr += args.ns02;
|
||||
k_ptr += args.ns12;
|
||||
v_ptr += args.ns22;
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
dst_attn += args.ne21*S_v;
|
||||
}
|
||||
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
device T * dstt_state = (device T *) (dst_state);
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is*S_v] = lsf[j];
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
#undef G
|
||||
}
|
||||
|
||||
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
|
||||
|
||||
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
|
||||
#endif
|
||||
@@ -0,0 +1,347 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_argmax_f32(
|
||||
constant ggml_metal_kargs_argmax & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
|
||||
|
||||
float lmax = -INFINITY;
|
||||
int32_t larg = -1;
|
||||
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
if (x_row[i00] > lmax) {
|
||||
lmax = x_row[i00];
|
||||
larg = i00;
|
||||
}
|
||||
}
|
||||
|
||||
// find the argmax value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
|
||||
|
||||
device int32_t * dst_i32 = (device int32_t *) dst;
|
||||
|
||||
threadgroup float * shared_maxval = (threadgroup float *) shmem;
|
||||
threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
shared_maxval[tiisg] = -INFINITY;
|
||||
shared_argmax[tiisg] = -1;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shared_maxval[sgitg] = max_val;
|
||||
shared_argmax[sgitg] = arg_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = shared_maxval[tiisg];
|
||||
arg_val = shared_argmax[tiisg];
|
||||
|
||||
float max_val_reduced = simd_max(max_val);
|
||||
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
|
||||
|
||||
dst_i32[tgpig] = arg_val_reduced;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
dst_i32[tgpig] = arg_val;
|
||||
}
|
||||
|
||||
kernel void kernel_diag_f32(
|
||||
constant ggml_metal_kargs_diag & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
const int32_t i3 = tgpig.z;
|
||||
const int32_t i2 = tgpig.y;
|
||||
const int32_t i1 = tgpig.x;
|
||||
|
||||
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
|
||||
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_roll_f32(
|
||||
constant ggml_metal_kargs_roll & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
device const float * src0_ptr = (device const float *) src0;
|
||||
device float * dst_ptr = (device float *) dst;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// apply shifts and wrap around
|
||||
int64_t i00 = i0 - args.s0;
|
||||
int64_t i01 = i1 - args.s1;
|
||||
int64_t i02 = i2 - args.s2;
|
||||
int64_t i03 = i3 - args.s3;
|
||||
|
||||
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
|
||||
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
|
||||
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
|
||||
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
|
||||
|
||||
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
|
||||
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
|
||||
|
||||
dst_ptr[dst_idx] = src0_ptr[src_idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_pad_impl(
|
||||
constant ggml_metal_kargs_pad & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i3 = tgpig.z;
|
||||
const int32_t i2 = tgpig.y;
|
||||
const int32_t k0 = tgpig.x/args.ne1;
|
||||
const int32_t i1 = tgpig.x - k0*args.ne1;
|
||||
|
||||
const int32_t i03 = i3;
|
||||
const int32_t i02 = i2;
|
||||
const int32_t i01 = i1;
|
||||
|
||||
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
|
||||
const int32_t i0 = k0*1024 + tpitg.x + l0;
|
||||
if (i0 >= args.ne0) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
dst_ptr[i0] = src0_ptr[i0];
|
||||
} else {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
|
||||
|
||||
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
|
||||
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
|
||||
|
||||
// TODO: this is slow - optimize
|
||||
kernel void kernel_pad_reflect_1d_f32(
|
||||
constant ggml_metal_kargs_pad_reflect_1d & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3;
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i01 = i1;
|
||||
|
||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
if (i0 < args.p0) {
|
||||
dst_ptr[i0] = src0_ptr[args.p0 - i0];
|
||||
} else if (i0 < args.ne0 - args.p1) {
|
||||
dst_ptr[i0] = src0_ptr[i0 - args.p0];
|
||||
} else {
|
||||
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_arange_f32(
|
||||
constant ggml_metal_kargs_arange & args,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
device float * dst_ptr = (device float *) dst;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
dst_ptr[i0] = args.start + args.step * i0;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_timestep_embedding_f32(
|
||||
constant ggml_metal_kargs_timestep_embedding & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
int i = tgpig.x;
|
||||
device float * embed_data = (device float *)(dst + i*args.nb1);
|
||||
|
||||
int half_ = args.dim / 2;
|
||||
for (int j = tpitg.x; j < half_; j += ntg.x) {
|
||||
float timestep = ((device float *)src0)[i];
|
||||
float freq = (float)exp(-log((float)args.max_period) * j / half_);
|
||||
float arg = timestep * freq;
|
||||
embed_data[j ] = cos(arg);
|
||||
embed_data[j + half_] = sin(arg);
|
||||
}
|
||||
|
||||
if (args.dim % 2 != 0 && tpitg.x == 0) {
|
||||
embed_data[2 * half_] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_adamw_f32(
|
||||
constant ggml_metal_kargs_opt_step_adamw & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device float * g_m,
|
||||
device float * g_v,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float alpha = pars[0];
|
||||
const float beta1 = pars[1];
|
||||
const float beta2 = pars[2];
|
||||
const float eps = pars[3];
|
||||
const float wd = pars[4];
|
||||
const float beta1h = pars[5];
|
||||
const float beta2h = pars[6];
|
||||
|
||||
const float gi = g[gid];
|
||||
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
|
||||
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
|
||||
|
||||
g_m[gid] = gmi;
|
||||
g_v[gid] = gvi;
|
||||
|
||||
const float mh = gmi * beta1h;
|
||||
const float vh = sqrt(gvi * beta2h) + eps;
|
||||
|
||||
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_sgd_f32(
|
||||
constant ggml_metal_kargs_opt_step_sgd & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_memset(
|
||||
constant ggml_metal_kargs_memset & args,
|
||||
device T * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
|
||||
|
||||
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
|
||||
|
||||
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_count_equal(
|
||||
constant ggml_metal_kargs_count_equal & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device atomic_int * dst,
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const short NSG = FC_count_equal_nsg;
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int sum = 0;
|
||||
|
||||
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
|
||||
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
const T v0 = *(device const T *)(base0 + i0*args.nb00);
|
||||
const T v1 = *(device const T *)(base1 + i0*args.nb10);
|
||||
sum += (v0 == v1);
|
||||
}
|
||||
|
||||
sum = simd_sum(sum);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_i32[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
float v = 0.0f;
|
||||
if (tpitg.x < NSG) {
|
||||
v = shmem_i32[tpitg.x];
|
||||
}
|
||||
|
||||
float total = simd_sum(v);
|
||||
if (tpitg.x == 0) {
|
||||
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
|
||||
|
||||
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
|
||||
@@ -0,0 +1,838 @@
|
||||
#include "common.h"
|
||||
#include "dequantize.h"
|
||||
|
||||
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
||||
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
||||
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
|
||||
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
|
||||
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
|
||||
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
template<
|
||||
typename SA, typename SA_4x4, typename SA_8x8,
|
||||
typename SB, typename SB_2x4, typename SB_8x8,
|
||||
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
|
||||
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * srcA,
|
||||
device const char * srcB,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
ushort tiitg [[thread_index_in_threadgroup]],
|
||||
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
|
||||
(void) sgitg;
|
||||
|
||||
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
|
||||
const int K = args.ne00;
|
||||
const int M = args.ne0;
|
||||
const int N = args.ne1;
|
||||
|
||||
// Batch dimension handling
|
||||
const int im = tgpig.z;
|
||||
const int i12 = im % FC_mul_mm_ne12;
|
||||
const int i13 = im / FC_mul_mm_ne12;
|
||||
|
||||
// Batch offsets for srcA and srcB
|
||||
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
||||
|
||||
// Tile dimensions
|
||||
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
|
||||
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
|
||||
|
||||
// Tile offsets in output matrix
|
||||
const int ra = tgpig.y * NRA;
|
||||
const int rb = tgpig.x * NRB;
|
||||
|
||||
// Threadgroup memory for dequantized A tile only
|
||||
threadgroup SA * sa = (threadgroup SA *)(shmem);
|
||||
|
||||
// Work-item count for A loading
|
||||
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
|
||||
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
|
||||
|
||||
// tA wraps threadgroup memory
|
||||
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
|
||||
|
||||
// tB wraps device memory directly
|
||||
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
|
||||
const int strideB = args.nb11 / sizeof(T1);
|
||||
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
|
||||
|
||||
// Configure matmul operation
|
||||
mpp::tensor_ops::matmul2d<
|
||||
mpp::tensor_ops::matmul2d_descriptor(
|
||||
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
|
||||
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
||||
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
|
||||
|
||||
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
|
||||
|
||||
// Accumulate partial results over K dimension
|
||||
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
|
||||
// === PHASE 1: Dequantization of A into threadgroup memory ===
|
||||
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
|
||||
const int row = work / N_MM_NK;
|
||||
const int k_chunk = work % N_MM_NK;
|
||||
const int k_pos = loop_k + k_chunk * 16;
|
||||
const short k_base = k_chunk * 16;
|
||||
|
||||
// Bounds check: skip device read if row is out of matrix bounds
|
||||
if (ra + row < M) {
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
|
||||
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
|
||||
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
|
||||
// Mirrors the legacy kernel's existing guard.
|
||||
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
|
||||
}
|
||||
} else {
|
||||
const int block_idx = k_pos / (16 * nl);
|
||||
const short il = (k_pos / 16) % nl;
|
||||
|
||||
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
|
||||
|
||||
SA_4x4 temp_a;
|
||||
dequantize_func(row_ptr + block_idx, il, temp_a);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Zero-pad rows beyond matrix bounds
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// === PHASE 2: Tensor matmul ===
|
||||
auto mA = tA.slice(0, 0);
|
||||
auto mB = tB.slice(loop_k, rb);
|
||||
|
||||
mm.run(mB, mA, cT);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// Store result tile to output matrix (with batch offset)
|
||||
// cT.store handles bounds checking via tD's extents (M, N)
|
||||
device float * dstBatch = (device float *)dst + im * N * M;
|
||||
|
||||
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
|
||||
cT.store(tD.slice(ra, rb));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template<
|
||||
typename S0, typename S0_4x4, typename S0_8x8,
|
||||
typename S1, typename S1_2x4, typename S1_8x8,
|
||||
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
|
||||
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
|
||||
constexpr int NK = 32;
|
||||
constexpr int NL0 = NK/16;
|
||||
constexpr int NL1 = NK/8;
|
||||
|
||||
const int im = tgpig.z;
|
||||
const int r0 = tgpig.y*NR0;
|
||||
const int r1 = tgpig.x*NR1;
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
|
||||
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
||||
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
||||
|
||||
const short il0 = (tiitg % NL0);
|
||||
|
||||
short il = il0;
|
||||
|
||||
const int i12 = im % FC_mul_mm_ne12;
|
||||
const int i13 = im / FC_mul_mm_ne12;
|
||||
|
||||
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
||||
const short offset1 = il0/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
||||
|
||||
const short iy = 8*(tiitg % NL1);
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*(r1 + lr1)
|
||||
+ args.nb10*iy);
|
||||
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
// NOTE: this is massively slower.. WTF?
|
||||
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||
}
|
||||
} else {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
|
||||
y += NK;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||
|
||||
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += 8*64;
|
||||
lsmb += 4*64;
|
||||
}
|
||||
}
|
||||
|
||||
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
||||
// if no bounds checks on the output are needed, we can directly write to device memory
|
||||
device float * C = (device float *) dst +
|
||||
(r0 + 32*(sgitg & 1)) + \
|
||||
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < nr1; j += NR1) {
|
||||
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*NR0);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = 0;
|
||||
for (; i < nr0/4; i++) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i *= 4;
|
||||
for (; i < nr0; i++) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_METAL_HAS_TENSOR
|
||||
|
||||
template<short ne20> // n_expert_used
|
||||
kernel void kernel_mul_mm_id_map0(
|
||||
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||
device const char * src2,
|
||||
device char * htpe,
|
||||
device char * hids,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const short ide = tpitg; // expert id
|
||||
|
||||
uint32_t n_all = 0;
|
||||
|
||||
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
|
||||
|
||||
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
|
||||
if (i21 + tpitg < args.ne21) {
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
|
||||
|
||||
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
|
||||
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sids[i20] = src2_i32[i20];
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short t = 0; t < ntg; t++) {
|
||||
if (i21 + t >= args.ne21) {
|
||||
break;
|
||||
}
|
||||
|
||||
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
|
||||
|
||||
short sel = 0;
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sel += (sids[i20] == ide)*(i20 + 1);
|
||||
}
|
||||
|
||||
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
|
||||
|
||||
n_all += sel > 0;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
|
||||
tpe_u32[ide] = n_all;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
||||
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * htpe,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
|
||||
constexpr int NK = 32;
|
||||
constexpr int NL0 = NK/16;
|
||||
constexpr int NL1 = NK/8;
|
||||
|
||||
const int im = tgpig.z; // expert
|
||||
const int r0 = tgpig.y*NR0;
|
||||
const int r1 = tgpig.x*NR1;
|
||||
|
||||
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
|
||||
const int32_t neh1 = tpe_u32[im];
|
||||
|
||||
if (r1 >= neh1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
|
||||
const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
||||
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
||||
|
||||
const short il0 = (tiitg % NL0);
|
||||
|
||||
short il = il0;
|
||||
|
||||
const int id = ids_i32[im*args.ne21 + r1 + lr1];
|
||||
|
||||
const short i11 = (id % args.ne20) % args.ne11;
|
||||
const short i12 = (id / args.ne20);
|
||||
const short i13 = 0;
|
||||
|
||||
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
|
||||
const short offset1 = il0/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
||||
|
||||
const short iy = 8*(tiitg % NL1);
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*i11
|
||||
+ args.nb10*iy);
|
||||
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
#else
|
||||
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
||||
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
||||
|
||||
mpp::tensor_ops::matmul2d<
|
||||
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
||||
execution_simdgroups<4>> mm;
|
||||
|
||||
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
||||
#endif
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
// NOTE: this is massively slower.. WTF?
|
||||
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||
}
|
||||
} else {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
#else
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
const short lx = i%8;
|
||||
const short ly = (tiitg/NL0)%8;
|
||||
//const short lx = (tiitg/NL0)%8;
|
||||
//const short ly = i%8;
|
||||
|
||||
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
const short lx = i%8;
|
||||
const short ly = (tiitg/NL0)%8;
|
||||
//const short lx = (tiitg/NL0)%8;
|
||||
//const short ly = i%8;
|
||||
|
||||
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||
}
|
||||
} else {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
//const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
#endif
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
|
||||
y += NK;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||
|
||||
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += 8*64;
|
||||
lsmb += 4*64;
|
||||
}
|
||||
#else
|
||||
auto sA = tA.slice(0, 0);
|
||||
auto sB = tB.slice(0, 0);
|
||||
|
||||
mm.run(sB, sA, cT);
|
||||
#endif
|
||||
}
|
||||
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
||||
cT.store(tC);
|
||||
#else
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
||||
}
|
||||
#endif
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short j = sgitg; j < nr1; j += 4) {
|
||||
const int id = ids_i32[im*args.ne21 + r1 + j];
|
||||
|
||||
const short ide = id % args.ne20;
|
||||
const short idt = id / args.ne20;
|
||||
|
||||
device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = (threadgroup float *) shmem + j*NR0;
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = tiisg;
|
||||
for (; i < nr0/4; i += 32) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i = (4*(nr0/4)) + tiisg;
|
||||
for (; i < nr0; i += 32) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,308 @@
|
||||
#include "common.h"
|
||||
|
||||
// F == 1 : norm (no fuse)
|
||||
// F == 2 : norm + mul
|
||||
// F == 3 : norm + mul + add
|
||||
template <typename T, short F>
|
||||
kernel void kernel_norm_fuse_impl(
|
||||
constant ggml_metal_kargs_norm & args,
|
||||
device const char * src0,
|
||||
device const char * src1_0,
|
||||
device const char * src1_1,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
const int i01 = tgpig.x;
|
||||
const int i02 = tgpig.y;
|
||||
const int i03 = tgpig.z;
|
||||
|
||||
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
||||
|
||||
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
||||
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
||||
|
||||
T sumft(0.0f);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
sumft += x[i00];
|
||||
}
|
||||
sumf = dot(sumft, T(1.0f));
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float mean = sumf/args.ne00;
|
||||
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
|
||||
sumf = 0.0f;
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
y[i00] = x[i00] - mean;
|
||||
sumf += dot(y[i00], y[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float variance = sumf/args.ne00;
|
||||
|
||||
const float scale = 1.0f/sqrt(variance + args.eps);
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
if (F == 1) {
|
||||
y[i00] = (y[i00]*scale);
|
||||
}
|
||||
if (F == 2) {
|
||||
y[i00] = (y[i00]*scale)*f0[i00];
|
||||
}
|
||||
if (F == 3) {
|
||||
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
|
||||
|
||||
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
|
||||
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
|
||||
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
|
||||
|
||||
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
|
||||
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
|
||||
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
|
||||
|
||||
// F == 1 : rms_norm (no fuse)
|
||||
// F == 2 : rms_norm + mul
|
||||
// F == 3 : rms_norm + mul + add
|
||||
template <typename T, short F>
|
||||
kernel void kernel_rms_norm_fuse_impl(
|
||||
constant ggml_metal_kargs_norm & args,
|
||||
device const char * src0,
|
||||
device const char * src1_0,
|
||||
device const char * src1_1,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
const int i01 = tgpig.x;
|
||||
const int i02 = tgpig.y;
|
||||
const int i03 = tgpig.z;
|
||||
|
||||
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
||||
|
||||
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
||||
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float mean = sumf/args.ne00;
|
||||
const float scale = 1.0f/sqrt(mean + args.eps);
|
||||
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
if (F == 1) {
|
||||
y[i00] = (x[i00]*scale);
|
||||
}
|
||||
if (F == 2) {
|
||||
y[i00] = (x[i00]*scale)*f0[i00];
|
||||
}
|
||||
if (F == 3) {
|
||||
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
|
||||
|
||||
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
|
||||
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
|
||||
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
|
||||
|
||||
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
|
||||
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
||||
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
||||
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_l2_norm_impl(
|
||||
constant ggml_metal_kargs_l2_norm & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float scale = 1.0f/max(sqrt(sumf), args.eps);
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
|
||||
|
||||
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
|
||||
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t ne = args.ne00*args.ne01*args.ne02;
|
||||
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
|
||||
|
||||
int start = tgpig * gs;
|
||||
int end = start + gs;
|
||||
|
||||
start += tpitg;
|
||||
|
||||
if (end >= ne) {
|
||||
end = ne;
|
||||
}
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
tmp += src0[j];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
tmp = simd_sum(tmp);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = tmp;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
tmp = buf[tiisg];
|
||||
tmp = simd_sum(tmp);
|
||||
}
|
||||
|
||||
const float mean = tmp / gs;
|
||||
tmp = 0.0f;
|
||||
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
float xi = src0[j] - mean;
|
||||
dst[j] = xi;
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
tmp = simd_sum(tmp);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = tmp;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
tmp = buf[tiisg];
|
||||
tmp = simd_sum(tmp);
|
||||
}
|
||||
|
||||
const float variance = tmp / gs;
|
||||
const float scale = 1.0f/sqrt(variance + args.eps);
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
dst[j] *= scale;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_pool_2d_max_f32(
|
||||
constant ggml_metal_kargs_pool_2d & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = args.IH * args.IW;
|
||||
const int O_HW = args.OH * args.OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / args.OW;
|
||||
const int cur_ow = idx % O_HW % args.OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * args.s1 - args.p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(args.IH, start_h + args.k1);
|
||||
const int start_w = cur_ow * args.s0 - args.p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(args.IW, start_w + args.k0);
|
||||
|
||||
float res = -INFINITY;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
res = MAX(res, i_ptr[i * args.IW + j]);
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_2d_avg_f32(
|
||||
constant ggml_metal_kargs_pool_2d & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = args.IH * args.IW;
|
||||
const int O_HW = args.OH * args.OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / args.OW;
|
||||
const int cur_ow = idx % O_HW % args.OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * args.s1 - args.p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(args.IH, start_h + args.k1);
|
||||
const int start_w = cur_ow * args.s0 - args.p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(args.IW, start_w + args.k0);
|
||||
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
||||
const float scale = 1. / (args.k0 * args.k1);
|
||||
|
||||
float res = 0;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
float cur = i_ptr[i * args.IW + j];
|
||||
res += cur * scale;
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
|
||||
kernel void kernel_pool_1d_max_f32(
|
||||
constant ggml_metal_kargs_pool_1d & args,
|
||||
device const float * src,
|
||||
device float * dst,
|
||||
uint gid [[thread_position_in_grid]]
|
||||
) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ow = (int)gid % args.OW;
|
||||
const int row = (int)gid / args.OW;
|
||||
|
||||
const int base = ow * args.s0 - args.p0;
|
||||
|
||||
float acc = -INFINITY;
|
||||
|
||||
const int src_off = row * args.IW;
|
||||
const int dst_off = row * args.OW;
|
||||
|
||||
for (int ki = 0; ki < args.k0; ++ki) {
|
||||
int j = base + ki;
|
||||
if (j < 0 || j >= args.IW){
|
||||
continue;
|
||||
}
|
||||
float v = src[src_off + j];
|
||||
acc = max(acc, v);
|
||||
}
|
||||
|
||||
dst[dst_off + ow] = acc;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_1d_avg_f32(
|
||||
constant ggml_metal_kargs_pool_1d & args,
|
||||
device const float * src,
|
||||
device float * dst,
|
||||
uint gid [[thread_position_in_grid]]
|
||||
) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ow = (int)gid % args.OW;
|
||||
const int row = (int)gid / args.OW;
|
||||
|
||||
const int base = ow * args.s0 - args.p0;
|
||||
|
||||
float acc = 0.0f;
|
||||
int cnt = 0;
|
||||
|
||||
const int src_off = row * args.IW;
|
||||
const int dst_off = row * args.OW;
|
||||
|
||||
for (int ki = 0; ki < args.k0; ++ki) {
|
||||
const int j = base + ki;
|
||||
if (j < 0 || j >= args.IW) {
|
||||
continue;
|
||||
}
|
||||
acc += src[src_off + j];
|
||||
cnt += 1;
|
||||
}
|
||||
|
||||
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
sum_abs += fabs(src[j]);
|
||||
}
|
||||
dst.d = sum_abs / QK1_0;
|
||||
|
||||
for (int j = 0; j < QK1_0 / 8; j++) {
|
||||
dst.qs[j] = 0;
|
||||
}
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
if (src[j] >= 0.0f) {
|
||||
dst.qs[j / 8] |= (1 << (j % 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
for (int j = 0; j < QK4_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
||||
|
||||
dst.qs[j] = xi0;
|
||||
dst.qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int j = 0; j < QK4_1; j++) {
|
||||
const float v = src[j];
|
||||
if (min > v) min = v;
|
||||
if (max < v) max = v;
|
||||
}
|
||||
|
||||
const float d = (max - min) / ((1 << 4) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
dst.m = min;
|
||||
|
||||
for (int j = 0; j < QK4_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
||||
|
||||
dst.qs[j] = xi0;
|
||||
dst.qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK5_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -16;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK5_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
||||
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
||||
|
||||
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
||||
}
|
||||
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst.qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float max = src[0];
|
||||
float min = src[0];
|
||||
|
||||
for (int j = 1; j < QK5_1; j++) {
|
||||
const float v = src[j];
|
||||
min = v < min ? v : min;
|
||||
max = v > max ? v : max;
|
||||
}
|
||||
|
||||
const float d = (max - min) / 31;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
dst.m = min;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
||||
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
||||
|
||||
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
||||
}
|
||||
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst.qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const float v = src[j];
|
||||
amax = MAX(amax, fabs(v));
|
||||
}
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
for (int j = 0; j < QK8_0; ++j) {
|
||||
const float x0 = src[j]*id;
|
||||
|
||||
dst.qs[j] = round(x0);
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_NL; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / kvalues_iq4nl_f[0];
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int j = 0; j < QK4_NL/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_NL/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
||||
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
||||
|
||||
dst.qs[j] = xi0 | (xi1 << 4);
|
||||
|
||||
const float v0 = kvalues_iq4nl_f[xi0];
|
||||
const float v1 = kvalues_iq4nl_f[xi1];
|
||||
const float w0 = src[0 + j]*src[0 + j];
|
||||
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
||||
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
||||
sumq2 += w0*v0*v0 + w1*v1*v1;
|
||||
|
||||
}
|
||||
|
||||
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
|
||||
}
|
||||
@@ -0,0 +1,389 @@
|
||||
#include "common.h"
|
||||
#include "dequantize.h"
|
||||
#include "quantize.h"
|
||||
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_cpy_t_t(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
|
||||
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
dst_data[i00] = (T1) src[0];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
|
||||
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
|
||||
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
|
||||
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
|
||||
#endif
|
||||
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
|
||||
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
|
||||
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
|
||||
#endif
|
||||
|
||||
template<short QK,
|
||||
typename block_q,
|
||||
void (*quantize_func)(device const float *, device block_q &)>
|
||||
kernel void kernel_cpy_f32_q(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
||||
|
||||
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
|
||||
|
||||
quantize_func(src, dst_data[i00]);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
|
||||
|
||||
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_cpy_q_f32(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
T4x4 temp;
|
||||
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
||||
dst_data[i00] = temp;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_concat(
|
||||
constant ggml_metal_kargs_concat & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
|
||||
|
||||
if (i1 >= args.ne1) {
|
||||
return;
|
||||
}
|
||||
|
||||
int o[4] = {0, 0, 0, 0};
|
||||
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
device const T * x;
|
||||
|
||||
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
|
||||
} else {
|
||||
x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
|
||||
}
|
||||
|
||||
device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_concat<float>) kernel_concat_t;
|
||||
|
||||
template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
|
||||
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
|
||||
#endif
|
||||
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
|
||||
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
|
||||
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
|
||||
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
kernel void kernel_get_rows_q(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device void * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 ntg [[threads_per_threadgroup]]) {
|
||||
const int32_t iw0 = tgpig.x/args.ne10;
|
||||
const int32_t i10 = tgpig.x%args.ne10;
|
||||
const int32_t i11 = tgpig.y;
|
||||
const int32_t i12 = tgpig.z;
|
||||
|
||||
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
|
||||
|
||||
const int32_t i02 = i11;
|
||||
const int32_t i03 = i12;
|
||||
|
||||
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
|
||||
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
|
||||
|
||||
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
|
||||
float4x4 temp;
|
||||
dequantize_func(psrc + ind/nl, ind%nl, temp);
|
||||
pdst[ind] = temp;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T0, typename T>
|
||||
kernel void kernel_get_rows_f(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device void * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 ntg [[threads_per_threadgroup]]) {
|
||||
const int32_t iw0 = tgpig.x/args.ne10;
|
||||
const int32_t i10 = tgpig.x%args.ne10;
|
||||
const int32_t i11 = tgpig.y;
|
||||
const int32_t i12 = tgpig.z;
|
||||
|
||||
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
|
||||
|
||||
const int32_t i02 = i11;
|
||||
const int32_t i03 = i12;
|
||||
|
||||
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
|
||||
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
|
||||
|
||||
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
|
||||
pdst[ind] = psrc[ind];
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
||||
kernel void kernel_set_rows_q32(
|
||||
constant ggml_metal_kargs_set_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
|
||||
const int32_t i12 = i03%args.ne12;
|
||||
const int32_t i11 = i02%args.ne11;
|
||||
|
||||
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t i10 = i01;
|
||||
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
||||
|
||||
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
|
||||
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
||||
quantize_func(src_row + 32*ind, dst_row[ind]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename TI>
|
||||
kernel void kernel_set_rows_f(
|
||||
constant ggml_metal_kargs_set_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
|
||||
const int32_t i12 = i03%args.ne12;
|
||||
const int32_t i11 = i02%args.ne11;
|
||||
|
||||
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t i10 = i01;
|
||||
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
||||
|
||||
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
|
||||
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
||||
dst_row[ind] = (T) src_row[ind];
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// get rows
|
||||
//
|
||||
|
||||
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
|
||||
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
|
||||
#endif
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
|
||||
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// set rows
|
||||
//
|
||||
|
||||
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
|
||||
|
||||
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
|
||||
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
|
||||
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
|
||||
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
|
||||
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
|
||||
#endif
|
||||
|
||||
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
||||
|
||||
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
|
||||
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
|
||||
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
|
||||
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
|
||||
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
|
||||
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
|
||||
@@ -0,0 +1,228 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_op_sum_f32(
|
||||
constant ggml_metal_kargs_sum & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
if (args.np == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: become function constant
|
||||
const uint nsg = (ntg.x + 31) / 32;
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
|
||||
sumf += src0[i0];
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float total = 0;
|
||||
|
||||
if (sgitg == 0) {
|
||||
float v = 0;
|
||||
|
||||
if (tpitg.x < nsg) {
|
||||
v = shmem_f32[tpitg.x];
|
||||
}
|
||||
|
||||
total = simd_sum(v);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst[0] = total;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
|
||||
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_sum_rows_impl(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_sum_rows_op
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_t[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
T0 sumf = T0(0.0f);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_t[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_t[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
|
||||
if (is_same<float4, T0>::value) {
|
||||
dst_row[0] = sum(sumf) / (4*args.ne00);
|
||||
} else {
|
||||
dst_row[0] = sum(sumf) / args.ne00;
|
||||
}
|
||||
} else {
|
||||
dst_row[0] = sum(sumf);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_blk(
|
||||
constant ggml_metal_kargs_cumsum_blk & args,
|
||||
device const char * src0,
|
||||
device char * tmp,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * src0_row = (device const float *) (src0 +
|
||||
args.nb01*i01 +
|
||||
args.nb02*i02 +
|
||||
args.nb03*i03);
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
|
||||
float v = 0.0f;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
v = src0_row[i00 + tpitg.x];
|
||||
}
|
||||
|
||||
float s = simd_prefix_inclusive_sum(v);
|
||||
|
||||
if (tiisg == N_SIMDWIDTH - 1) {
|
||||
shmem_f32[sgitg] = s;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
s += shmem_f32[sgitg];
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] = s;
|
||||
}
|
||||
|
||||
if (args.outb && tpitg.x == ntg.x - 1) {
|
||||
device float * tmp_row = (device float *) tmp +
|
||||
args.net0*i01 +
|
||||
args.net0*args.net1*i02 +
|
||||
args.net0*args.net1*args.net2*i03;
|
||||
|
||||
tmp_row[ib] = s;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_add(
|
||||
constant ggml_metal_kargs_cumsum_add & args,
|
||||
device const char * tmp,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
if (ib == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * tmp_row = (device const float *) (tmp +
|
||||
args.nbt1*i01 +
|
||||
args.nbt2*i02 +
|
||||
args.nbt3*i03);
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
||||
@@ -0,0 +1,318 @@
|
||||
#include "common.h"
|
||||
|
||||
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
|
||||
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
|
||||
thread float * cos_theta, thread float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = cos(theta) * mscale;
|
||||
*sin_theta = sin(theta) * mscale;
|
||||
if (FC_rope_is_back) {
|
||||
*sin_theta *= -1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
}
|
||||
|
||||
static void rope_yarn_corr_dims(
|
||||
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_norm(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float theta_base = (float) pos[i2];
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_neox(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float theta_base = (float) pos[i2];
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_multi(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations
|
||||
// note: the rest is the same as kernel_rope_neox
|
||||
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
|
||||
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
|
||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
if (FC_rope_is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
|
||||
theta_base = (float) pos[i2 + args.ne02 * 0];
|
||||
} else { // e
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
} else {
|
||||
if (sector < args.sect_0) {
|
||||
theta_base = (float) pos[i2];
|
||||
} else if (sector < sec_w01) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
} else if (sector < sec_w012) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
}
|
||||
// end of mrope
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_vision(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations (only support 2 dimensions)
|
||||
const int sect_dims = args.sect_0 + args.sect_1;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float p;
|
||||
float theta_base;
|
||||
if (sector < args.sect_1) {
|
||||
p = (float) sector;
|
||||
theta_base = (float) pos[i2];
|
||||
} else {
|
||||
p = (float) sector - args.sect_0;
|
||||
theta_base = (float) pos[i2 + args.ne02];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
||||
// end of mrope
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
||||
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
||||
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
||||
|
||||
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
||||
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
||||
@@ -0,0 +1,223 @@
|
||||
#include "common.h"
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x;
|
||||
|
||||
const int32_t i13 = i03%args.ne13;
|
||||
const int32_t i12 = i02%args.ne12;
|
||||
const int32_t i11 = i01;
|
||||
|
||||
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
||||
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
|
||||
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int32_t h = i02;
|
||||
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
// This barrier fixes a failing test
|
||||
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max_val);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
pdst[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max_4(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x;
|
||||
|
||||
const int32_t i13 = i03%args.ne13;
|
||||
const int32_t i12 = i02%args.ne12;
|
||||
const int32_t i11 = i01;
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
||||
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int32_t h = i02;
|
||||
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
||||
float max_val = simd_max(lmax);
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||
|
||||
// This barrier fixes a failing test
|
||||
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max_val);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
pdst4[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
|
||||
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
|
||||
|
||||
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
|
||||
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
|
||||
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
|
||||
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
|
||||
@@ -0,0 +1,75 @@
|
||||
#include "common.h"
|
||||
|
||||
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
|
||||
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
|
||||
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
const short NSG = FC_solve_tri_nsg;
|
||||
const short N = FC_solve_tri_n;
|
||||
const short K = FC_solve_tri_k;
|
||||
const short NP = PAD2(N, NW);
|
||||
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x*NSG + sgitg;
|
||||
|
||||
threadgroup float * sh0 = (threadgroup float *) shmem;
|
||||
|
||||
device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
|
||||
device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
|
||||
device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
|
||||
|
||||
for (short rr = 0; rr < N; rr += NSG) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
threadgroup float * sh0_cur = sh0 + sgitg*NP;
|
||||
|
||||
for (short t = 0; t*NW < N; ++t) {
|
||||
const short idx = t*NW + tiisg;
|
||||
sh0_cur[idx] = src0_ptr[idx];
|
||||
}
|
||||
|
||||
src0_ptr += NSG*N;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (i01 >= args.ne10) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
|
||||
const short r = rr + ir;
|
||||
|
||||
threadgroup float * sh0_cur = sh0 + ir*NP;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (short t = 0; t*NW < r; ++t) {
|
||||
const short idx = t*NW + tiisg;
|
||||
sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
|
||||
}
|
||||
|
||||
sum = simd_sum(sum);
|
||||
|
||||
if (tiisg == 0) {
|
||||
const float diag = sh0_cur[r];
|
||||
|
||||
dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
#include "common.h"
|
||||
|
||||
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
||||
kernel void kernel_ssm_conv_f32_f32(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i3 = tgpig.z;
|
||||
|
||||
const int64_t nc = args.ne10;
|
||||
//const int64_t ncs = args.ne00;
|
||||
//const int64_t nr = args.ne01;
|
||||
//const int64_t n_t = args.ne1;
|
||||
//const int64_t n_s = args.ne2;
|
||||
|
||||
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||
sumf += s[i0] * c[i0];
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
kernel void kernel_ssm_conv_f32_f32_4(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i3 = tgpig.z;
|
||||
|
||||
const int64_t nc = args.ne10;
|
||||
//const int64_t ncs = args.ne00;
|
||||
//const int64_t nr = args.ne01;
|
||||
//const int64_t n_t = args.ne1;
|
||||
//const int64_t n_s = args.ne2;
|
||||
|
||||
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
||||
sumf += dot(s[i0], c[i0]);
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
|
||||
|
||||
// Batched version: each threadgroup processes multiple tokens for better efficiency
|
||||
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
|
||||
kernel void kernel_ssm_conv_f32_f32_batched(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
// tgpig.x = row index (ir)
|
||||
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
||||
// tgpig.z = sequence index (i3)
|
||||
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
||||
const short BATCH_SIZE = FC_ssm_conv_bs;
|
||||
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2_off = tpitg.x;
|
||||
const int64_t i2 = i2_base + i2_off;
|
||||
|
||||
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
||||
const int64_t n_t = args.ne1; // number of tokens
|
||||
|
||||
// Bounds check for partial batches at the end
|
||||
if (i2 >= n_t) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Load conv weights (shared across all tokens for this row)
|
||||
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
||||
|
||||
// Load source for this specific token
|
||||
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
|
||||
// Output location for this token
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||
sumf += s[i0] * c[i0];
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
kernel void kernel_ssm_conv_f32_f32_batched_4(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
// tgpig.x = row index (ir)
|
||||
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
||||
// tgpig.z = sequence index (i3)
|
||||
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
||||
const short BATCH_SIZE = FC_ssm_conv_bs;
|
||||
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2_off = tpitg.x;
|
||||
const int64_t i2 = i2_base + i2_off;
|
||||
|
||||
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
||||
const int64_t n_t = args.ne1; // number of tokens
|
||||
|
||||
// Bounds check for partial batches at the end
|
||||
if (i2 >= n_t) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Load conv weights (shared across all tokens for this row)
|
||||
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
||||
|
||||
// Load source for this specific token
|
||||
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
|
||||
// Output location for this token
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
||||
sumf += dot(s[i0], c[i0]);
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
||||
// Optimized version: reduces redundant memory loads by having one thread load shared values
|
||||
kernel void kernel_ssm_scan_f32(
|
||||
constant ggml_metal_kargs_ssm_scan & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device const void * src2,
|
||||
device const void * src3,
|
||||
device const void * src4,
|
||||
device const void * src5,
|
||||
device const void * src6,
|
||||
device float * dst,
|
||||
threadgroup float * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgptg[[simdgroups_per_threadgroup]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
// Shared memory layout:
|
||||
// [0..sgptg*NW-1]: partial sums for reduction (existing)
|
||||
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
|
||||
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
|
||||
threadgroup float * shared_sums = shared;
|
||||
threadgroup float * shared_x_dt = shared + sgptg * NW;
|
||||
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
|
||||
|
||||
shared_sums[tpitg.x] = 0.0f;
|
||||
|
||||
const int32_t i0 = tpitg.x;
|
||||
const int32_t i1 = tgpig.x;
|
||||
const int32_t ir = tgpig.y; // current head
|
||||
const int32_t i3 = tgpig.z; // current seq
|
||||
|
||||
const int32_t nc = args.d_state;
|
||||
const int32_t nr = args.d_inner;
|
||||
const int32_t nh = args.n_head;
|
||||
const int32_t ng = args.n_group;
|
||||
const int32_t n_t = args.n_seq_tokens;
|
||||
|
||||
const int32_t s_off = args.s_off;
|
||||
|
||||
device const int32_t * ids = (device const int32_t *) src6;
|
||||
|
||||
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
||||
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||
|
||||
const int32_t i = i0 + i1*nc;
|
||||
const int32_t g = ir / (nh / ng); // repeat_interleave
|
||||
|
||||
float s0 = s0_buff[i];
|
||||
float s = 0.0f;
|
||||
|
||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
|
||||
|
||||
const float A0 = A[i0%args.ne30];
|
||||
|
||||
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
|
||||
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
|
||||
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
|
||||
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
|
||||
|
||||
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
|
||||
|
||||
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Pre-compute x_dt and dA for this batch of tokens
|
||||
// Only first sgptg threads do the loads and expensive math
|
||||
if (i0 < sgptg && i2 + i0 < n_t) {
|
||||
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
|
||||
device const float * x_t = x + i0 * args.ns12;
|
||||
device const float * dt_t = dt + i0 * args.ns21;
|
||||
|
||||
const float dt0 = dt_t[0];
|
||||
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
|
||||
shared_x_dt[i0] = x_t[0] * dtsp;
|
||||
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
|
||||
const float x_dt = shared_x_dt[t];
|
||||
const float dA = exp(shared_dA[t] * A0);
|
||||
|
||||
s = (s0 * dA) + (B[i0] * x_dt);
|
||||
|
||||
const float sumf = simd_sum(s * C[i0]);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shared_sums[t*NW + sgitg] = sumf;
|
||||
}
|
||||
|
||||
// recurse
|
||||
s0 = s;
|
||||
|
||||
B += args.ns42;
|
||||
C += args.ns52;
|
||||
}
|
||||
|
||||
// Advance pointers for next batch
|
||||
x += sgptg * args.ns12;
|
||||
dt += sgptg * args.ns21;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
|
||||
|
||||
if (tiisg == 0 && i2 + sgitg < n_t) {
|
||||
y[sgitg*nh*nr] = sumf;
|
||||
}
|
||||
|
||||
y += sgptg*nh*nr;
|
||||
}
|
||||
|
||||
s_buff[i] = s;
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
#include "common.h"
|
||||
|
||||
template<uint32_t ttype>
|
||||
bool _ggml_vec_tri_cmp(const int i, const int r);
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
|
||||
return i < r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
|
||||
return i <= r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
|
||||
return i > r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
|
||||
return i >= r;
|
||||
}
|
||||
|
||||
template<typename T, int ttype>
|
||||
kernel void kernel_tri(
|
||||
constant ggml_metal_kargs_tri & args,
|
||||
device const char * src0,
|
||||
device const char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
// Each thread is a single element of the row if ne00 < max threads per
|
||||
// threadgroup, so this will loop once for each index that this thread is
|
||||
// responsible for
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
// Use the comparison as a mask for branchless
|
||||
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
|
||||
|
||||
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
|
||||
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
|
||||
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
|
||||
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
|
||||
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
|
||||
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
|
||||
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
|
||||
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
|
||||
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
|
||||
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
|
||||
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
|
||||
#endif
|
||||
@@ -0,0 +1,360 @@
|
||||
#include "common.h"
|
||||
|
||||
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
|
||||
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
|
||||
|
||||
template <typename T0, typename T, typename TC>
|
||||
kernel void kernel_unary_impl(
|
||||
constant ggml_metal_kargs_unary & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_unary_op
|
||||
#define FC_CNT FC_unary_cnt
|
||||
|
||||
device const T0 * src0_ptr;
|
||||
device T * dst_ptr;
|
||||
|
||||
int i0;
|
||||
|
||||
if (FC_CNT) {
|
||||
i0 = tgpig.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0);
|
||||
dst_ptr = (device T *) (dst);
|
||||
} else {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int k0 = tgpig.x/args.ne01;
|
||||
const int i01 = tgpig.x - k0*args.ne01;
|
||||
|
||||
i0 = k0*ntg.x + tpitg.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
|
||||
}
|
||||
|
||||
{
|
||||
//threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (!FC_CNT) {
|
||||
if (i0 >= args.ne0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const TC x = (TC) src0_ptr[i0];
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SCALE) {
|
||||
dst_ptr[i0] = (T) (args.scale * x + args.bias);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_FILL) {
|
||||
dst_ptr[i0] = (T) args.val;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_CLAMP) {
|
||||
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQR) {
|
||||
dst_ptr[i0] = (T) (x * x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQRT) {
|
||||
dst_ptr[i0] = (T) sqrt(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIN) {
|
||||
dst_ptr[i0] = (T) sin(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_COS) {
|
||||
dst_ptr[i0] = (T) cos(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LOG) {
|
||||
dst_ptr[i0] = (T) log(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
|
||||
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_TANH) {
|
||||
dst_ptr[i0] = (T) precise::tanh(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_RELU) {
|
||||
dst_ptr[i0] = (T) fmax(0, x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
|
||||
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
|
||||
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SILU) {
|
||||
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ELU) {
|
||||
dst_ptr[i0] = (T) elu_approx(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_NEG) {
|
||||
dst_ptr[i0] = (T) -x;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ABS) {
|
||||
dst_ptr[i0] = (T) fabs(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SGN) {
|
||||
dst_ptr[i0] = T(x > 0) - T(x < 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_STEP) {
|
||||
dst_ptr[i0] = T(x > 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
|
||||
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
|
||||
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXP) {
|
||||
dst_ptr[i0] = (T) exp(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
|
||||
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXPM1) {
|
||||
// TODO: precise implementation
|
||||
dst_ptr[i0] = (T) (exp(x) - 1);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_FLOOR) {
|
||||
dst_ptr[i0] = (T) floor(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_CEIL) {
|
||||
dst_ptr[i0] = (T) ceil(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ROUND) {
|
||||
dst_ptr[i0] = (T) round(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_TRUNC) {
|
||||
dst_ptr[i0] = (T) trunc(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_XIELU) {
|
||||
const TC xi = x;
|
||||
const TC gate = TC(xi > TC(0.0f));
|
||||
const TC clamped = fmin(xi, TC(args.val));
|
||||
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
|
||||
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
|
||||
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
#undef FC_CNT
|
||||
}
|
||||
|
||||
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
|
||||
|
||||
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
|
||||
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
|
||||
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
|
||||
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_reglu(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
|
||||
|
||||
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
|
||||
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_geglu(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
||||
|
||||
dst_row[i0] = (T)(gelu*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
|
||||
|
||||
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
|
||||
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_swiglu(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float silu = x0 / (1.0f + exp(-x0));
|
||||
|
||||
dst_row[i0] = (T)(silu*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
|
||||
|
||||
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
|
||||
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_swiglu_oai(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
float x0 = src0_row[i0];
|
||||
float x1 = src1_row[i0];
|
||||
|
||||
x0 = min(x0, args.limit);
|
||||
x1 = max(min(x1, args.limit), -args.limit);
|
||||
|
||||
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
|
||||
out_glu = out_glu * (1.0f + x1);
|
||||
|
||||
dst_row[i0] = (T)out_glu;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
|
||||
|
||||
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
|
||||
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_geglu_erf(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
||||
|
||||
dst_row[i0] = (T)(gelu_erf*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
|
||||
|
||||
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
|
||||
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_geglu_quick(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
||||
|
||||
dst_row[i0] = (T)(gelu_quick*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
|
||||
|
||||
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
|
||||
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
|
||||
@@ -0,0 +1,179 @@
|
||||
#include "common.h"
|
||||
|
||||
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
|
||||
|
||||
kernel void kernel_upscale_nearest_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3/args.sf3;
|
||||
const int64_t i02 = i2/args.sf2;
|
||||
const int64_t i01 = i1/args.sf1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int64_t i00 = i0/args.sf0;
|
||||
|
||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_ptr[0] = src0_ptr[0];
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bilinear_tri(float x) {
|
||||
return MAX(0.0f, 1.0f - fabs(x));
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bilinear_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
|
||||
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
|
||||
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
|
||||
|
||||
src0 += i03*args.nb03 + i02*args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (FC_upscale_aa) {
|
||||
const float support0 = MAX(1.0f, 1.0f / args.sf0);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
const float support1 = MAX(1.0f, 1.0f / args.sf1);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
|
||||
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
|
||||
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
|
||||
|
||||
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
|
||||
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
|
||||
|
||||
float sum = 0.0f;
|
||||
float wsum = 0.0f;
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; ++sy) {
|
||||
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
|
||||
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
||||
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
||||
const float w = wx * wy;
|
||||
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
sum += (*src_ptr) * w;
|
||||
wsum += w;
|
||||
}
|
||||
}
|
||||
|
||||
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
|
||||
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
|
||||
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
|
||||
|
||||
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
|
||||
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
|
||||
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
|
||||
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
|
||||
|
||||
const float v =
|
||||
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
|
||||
(*src10) * fd0 * (1.0f - fd1) +
|
||||
(*src01) * (1.0f - fd0) * fd1 +
|
||||
(*src11) * fd0 * fd1;
|
||||
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bicubic_weight1(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a + 2) * x - (a + 3)) * x * x + 1;
|
||||
}
|
||||
|
||||
static inline float bicubic_weight2(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bicubic_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = (int64_t)floor(f01);
|
||||
const float fd1 = f01 - (float)i01;
|
||||
|
||||
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
|
||||
const float w_y1 = bicubic_weight1(fd1);
|
||||
const float w_y2 = bicubic_weight1(1.0f - fd1);
|
||||
const float w_y3 = bicubic_weight2(2.0f - fd1);
|
||||
|
||||
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = (int64_t)floor(f00);
|
||||
const float fd0 = f00 - (float)i00;
|
||||
|
||||
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
|
||||
const float w_x1 = bicubic_weight1(fd0);
|
||||
const float w_x2 = bicubic_weight1(1.0f - fd0);
|
||||
const float w_x3 = bicubic_weight2(2.0f - fd0);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int dy = -1; dy <= 2; ++dy) {
|
||||
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
|
||||
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
|
||||
|
||||
for (int dx = -1; dx <= 2; ++dx) {
|
||||
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
||||
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
||||
|
||||
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
sum += (*src_ptr) * wx * wy;
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = sum;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_rwkv_wkv6_f32(
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * r,
|
||||
device const float * tf,
|
||||
device const float * td,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _tf[head_size];
|
||||
threadgroup float _td[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_tf[tid] = tf[head_id * head_size + tid];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0;
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
float4 temp = tf_vec * kv + s_vec;
|
||||
y += dot(r_vec, temp);
|
||||
|
||||
s_vec = s_vec * td_vec + kv;
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rwkv_wkv7_f32(
|
||||
device const float * r,
|
||||
device const float * w,
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * a,
|
||||
device const float * b,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _w[head_size];
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _a[head_size];
|
||||
threadgroup float _b[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i];
|
||||
}
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0, sa = 0.0;
|
||||
|
||||
float4 sa_vec(0.0);
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
sa_vec += a_vec * s_vec;
|
||||
}
|
||||
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
||||
y += dot(s_vec, r_vec);
|
||||
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
+24
-99
@@ -1562,112 +1562,37 @@ static void test_msgs_oaicompat_json_conversion() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_msg_token_delimiters_split() {
|
||||
static void test_split_by_role() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Delimiters that share a leading token, distinguished by the second token,
|
||||
// to exercise the per-position token matching.
|
||||
const common_chat_msg_delimiters delims = {
|
||||
{ { COMMON_CHAT_ROLE_USER, "", { 10, 11 } },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } }
|
||||
};
|
||||
|
||||
// Empty inputs
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({}).spans.size());
|
||||
assert_equals<size_t>(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size());
|
||||
assert_equals<size_t>(0, delims.split({}).spans.size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
|
||||
|
||||
// No delimiters match -> no spans
|
||||
assert_equals<size_t>(0, delims.split({ 100, 101, 102 }).spans.size());
|
||||
|
||||
// Multi-role conversation: <user>Hi<assistant>Hello<user>Bye
|
||||
// Multi-role conversation, no leading/trailing content
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
100, 101, // Hi
|
||||
10, 12, // <assistant>
|
||||
200, 201, 202, // Hello
|
||||
10, 11, // <user>
|
||||
300, 301, // Bye
|
||||
};
|
||||
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
|
||||
const auto splits = common_chat_split_by_role(prompt, {
|
||||
{ "user", "<|user|>" },
|
||||
{ "assistant", "<|assistant|>" },
|
||||
});
|
||||
assert_equals<size_t>(3, splits.size());
|
||||
|
||||
const auto result = delims.split(tokens);
|
||||
const auto & spans = result.spans;
|
||||
assert_equals<size_t>(3, spans.size());
|
||||
assert_equals<std::string>("user", splits[0].role);
|
||||
assert_equals<size_t>(0, splits[0].pos);
|
||||
assert_equals<size_t>(10, splits[0].len);
|
||||
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(4, spans[0].len);
|
||||
assert_equals<std::string>("assistant", splits[1].role);
|
||||
assert_equals<size_t>(10, splits[1].pos);
|
||||
assert_equals<size_t>(18, splits[1].len);
|
||||
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(4, spans[1].pos);
|
||||
assert_equals<size_t>(5, spans[1].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role);
|
||||
assert_equals<size_t>(9, spans[2].pos);
|
||||
assert_equals<size_t>(4, spans[2].len);
|
||||
|
||||
// is_user_start() is true at the token position where a user span begins
|
||||
assert_equals(true, result.is_user_start(0));
|
||||
assert_equals(false, result.is_user_start(4)); // assistant span
|
||||
assert_equals(true, result.is_user_start(9));
|
||||
}
|
||||
|
||||
// Content before the first delimiter is not captured as a span
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
500, 501, // leading content (dropped)
|
||||
10, 11, // <user>
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const auto spans = delims.split(tokens).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(2, spans[0].pos);
|
||||
assert_equals<size_t>(3, spans[0].len);
|
||||
}
|
||||
|
||||
// Skipped regions (media chunks) are jumped over but still count as span content
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
LLAMA_TOKEN_NULL, // media chunk (3 tokens)
|
||||
LLAMA_TOKEN_NULL,
|
||||
LLAMA_TOKEN_NULL,
|
||||
100, // Hi
|
||||
10, 12, // <assistant>
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 3 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(2, spans.size());
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(6, spans[0].len);
|
||||
|
||||
assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role);
|
||||
assert_equals<size_t>(6, spans[1].pos);
|
||||
assert_equals<size_t>(2, spans[1].len);
|
||||
}
|
||||
|
||||
// A delimiter sequence inside a skipped region is not matched
|
||||
{
|
||||
const llama_tokens tokens = {
|
||||
10, 11, // <user>
|
||||
10, 12, // skipped region that happens to contain delimiter tokens
|
||||
100, // Hi
|
||||
};
|
||||
|
||||
const std::map<size_t, size_t> skips = { { 2, 2 } };
|
||||
|
||||
const auto spans = delims.split(tokens, skips).spans;
|
||||
assert_equals<size_t>(1, spans.size());
|
||||
assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role);
|
||||
assert_equals<size_t>(0, spans[0].pos);
|
||||
assert_equals<size_t>(5, spans[0].len);
|
||||
assert_equals<std::string>("user", splits[2].role);
|
||||
assert_equals<size_t>(28, splits[2].pos);
|
||||
assert_equals<size_t>(11, splits[2].len);
|
||||
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5932,7 +5857,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_msg_token_delimiters_split();
|
||||
test_split_by_role();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_convert_responses_to_chatcmpl();
|
||||
test_developer_role_to_system_workaround();
|
||||
|
||||
@@ -204,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will
|
||||
|
||||
The flow for downloading a new model:
|
||||
- POST request comes in --> `post_router_models` --> validation
|
||||
- A new `llama-server` subprocess will be spawned with special `SERVER_CHILD_MODE_DOWNLOAD`
|
||||
- Child process runs the download and report status back to router via stdin/out
|
||||
- If a stop request comes in, the router asks the child process to stop (same mechanism as running a model in child process)
|
||||
- `server_models::download()` is called
|
||||
- Sets up a new thread `inst.th` and runs the download inside
|
||||
- If a stop request comes in, set `stop_download` to `true`
|
||||
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
+5
-12
@@ -1230,6 +1230,8 @@ print(completion.choices[0].text)
|
||||
|
||||
Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
|
||||
|
||||
If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more.
|
||||
|
||||
*Options:*
|
||||
|
||||
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
|
||||
@@ -1248,18 +1250,9 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
|
||||
|
||||
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
|
||||
|
||||
For multimodal input (typed content, `messages[i].content[j]`):
|
||||
- If `type == "image_url"`:
|
||||
- `image_url.url` can be a remote URL, base64 (raw or URI-encoded via `data:image/...;base64`) or path to local file
|
||||
- Accepts formats supported by `stb_image` (jpeg, png, tga, bmp, gif, ...)
|
||||
- If `type == "input_audio"`:
|
||||
- Either `input_audio.data` or `input_audio.url` can be specified, can be a remote URL, raw base64 or path to local file
|
||||
- Accepts formats supported by `miniaudio` (mp3, wav, flac)
|
||||
- `input_audio.format` will be ignored, the file format will be determined automatically
|
||||
- If `type == "input_video"`:
|
||||
- Either `input_video.data` or `input_video.url` can be specified, can be a remote URL, raw base64 or path to local file
|
||||
- Accepts formats supported by `ffmpeg`
|
||||
- Note: for local file, make sure to set `--media-path`. File path must be prefixed by `file://`
|
||||
For multimodal input:
|
||||
- Content type `image_url` and `input_audio` are the same as OAI schema
|
||||
- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input
|
||||
|
||||
*Examples:*
|
||||
|
||||
|
||||
@@ -518,14 +518,6 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const {
|
||||
return max_idx; // all tokens are equal
|
||||
}
|
||||
|
||||
common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const {
|
||||
std::map<size_t, size_t> skips;
|
||||
for (const auto & it : map_idx_to_media) {
|
||||
skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get());
|
||||
}
|
||||
return delims.split(tokens, skips);
|
||||
}
|
||||
|
||||
bool server_tokens::validate(const struct llama_context * ctx) const {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
@@ -825,21 +817,12 @@ json oaicompat_completion_params_parse(const json & body) {
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
// url can be
|
||||
// - http(s):// for remote files
|
||||
// - file:// for local files (only allowed if media_path is set)
|
||||
// - data: for base64 encoded data with uri scheme (e.g. data:image/png;base64,...)
|
||||
// - raw base64 encoded data
|
||||
// media_path always end with '/', see arg.cpp
|
||||
static void handle_media(
|
||||
std::vector<raw_buffer> & out_files,
|
||||
const std::string & url,
|
||||
const std::string & media_path,
|
||||
bool accept_base64_uri) {
|
||||
if (!media_path.empty()) {
|
||||
// should already be enforced by arg.cpp, but checking just in case
|
||||
GGML_ASSERT(media_path.back() == DIRECTORY_SEPARATOR);
|
||||
}
|
||||
|
||||
json & media_obj,
|
||||
const std::string & media_path) {
|
||||
std::string url = json_value(media_obj, "url", std::string());
|
||||
if (string_starts_with(url, "http")) {
|
||||
// download remote image
|
||||
// TODO @ngxson : maybe make these params configurable
|
||||
@@ -875,28 +858,20 @@ static void handle_media(
|
||||
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
out_files.push_back(data);
|
||||
|
||||
} else if (accept_base64_uri && string_starts_with(url, "data:")) {
|
||||
} else {
|
||||
// try to decode base64 image
|
||||
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
|
||||
if (parts.size() != 2) {
|
||||
throw std::runtime_error("Invalid uri-encoded base64 value");
|
||||
throw std::runtime_error("Invalid url value");
|
||||
} else if (!string_starts_with(parts[0], "data:image/")) {
|
||||
throw std::runtime_error("Invalid uri format: " + parts[0]);
|
||||
throw std::runtime_error("Invalid url format: " + parts[0]);
|
||||
} else if (!string_ends_with(parts[0], "base64")) {
|
||||
throw std::runtime_error("uri must be base64 encoded");
|
||||
throw std::runtime_error("url must be base64 encoded");
|
||||
} else {
|
||||
auto base64_data = parts[1];
|
||||
auto decoded_data = base64_decode(base64_data);
|
||||
out_files.push_back(decoded_data);
|
||||
}
|
||||
|
||||
} else {
|
||||
// try as raw base64 string
|
||||
auto decoded_data = base64_decode(url);
|
||||
if (decoded_data.empty()) {
|
||||
throw std::runtime_error("Invalid base64 value");
|
||||
}
|
||||
out_files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -982,15 +957,14 @@ json oaicompat_chat_params_parse(
|
||||
}
|
||||
|
||||
for (auto & p : content) {
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
if (type == "image_url") {
|
||||
if (!opt.allow_image) {
|
||||
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json image_url = json_value(p, "image_url", json::object());
|
||||
std::string url = json_value(image_url, "url", std::string());
|
||||
handle_media(out_files, url, opt.media_path, true);
|
||||
handle_media(out_files, image_url, opt.media_path);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -1001,11 +975,17 @@ json oaicompat_chat_params_parse(
|
||||
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
// note: don't need to validate "format", it's redundant
|
||||
json input_audio = json_value(p, "input_audio", json::object());
|
||||
std::string url = json_value(input_audio, "data",
|
||||
json_value(input_audio, "url", std::string()));
|
||||
handle_media(out_files, url, opt.media_path, false);
|
||||
json input_audio = json_value(p, "input_audio", json::object());
|
||||
std::string data = json_value(input_audio, "data", std::string());
|
||||
std::string format = json_value(input_audio, "format", std::string());
|
||||
// while we also support flac, we don't allow it here so we matches the OAI spec
|
||||
if (format != "wav" && format != "mp3") {
|
||||
throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
|
||||
}
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
|
||||
// TODO: add audio_url support by reusing handle_media()
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -1016,10 +996,10 @@ json oaicompat_chat_params_parse(
|
||||
throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json input_video = json_value(p, "input_video", json::object());
|
||||
std::string url = json_value(input_video, "data",
|
||||
json_value(input_video, "url", std::string()));
|
||||
handle_media(out_files, url, opt.media_path, false);
|
||||
json input_video = json_value(p, "input_video", json::object());
|
||||
std::string data = json_value(input_video, "data", std::string());
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -1112,7 +1092,15 @@ json oaicompat_chat_params_parse(
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
llama_params["message_delimiters"] = chat_params.message_delimiters.to_json();
|
||||
llama_params["message_spans"] = json::array();
|
||||
|
||||
for (const auto & span : chat_params.message_spans) {
|
||||
llama_params["message_spans"].push_back({
|
||||
{ "role", span.role },
|
||||
{ "pos", span.pos },
|
||||
{ "len", span.len },
|
||||
});
|
||||
}
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
|
||||
@@ -218,9 +218,6 @@ public:
|
||||
|
||||
size_t get_common_prefix(const server_tokens & b) const;
|
||||
|
||||
// split the tokens into message spans, skipping over media chunks
|
||||
common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const;
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context * ctx) const;
|
||||
|
||||
|
||||
@@ -931,8 +931,6 @@ private:
|
||||
|
||||
bool sleeping = false;
|
||||
|
||||
int64_t t_last_load_progress_ms = 0;
|
||||
|
||||
void destroy() {
|
||||
spec.reset();
|
||||
ctx_dft.reset();
|
||||
@@ -1246,10 +1244,6 @@ private:
|
||||
}
|
||||
|
||||
if (has_mmproj) {
|
||||
if (callback_state) {
|
||||
callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}});
|
||||
}
|
||||
|
||||
if (!is_resume) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
}
|
||||
@@ -3436,8 +3430,8 @@ private:
|
||||
has_mtmd = true;
|
||||
}
|
||||
|
||||
const auto & spans = slot.task->params.message_spans;
|
||||
const auto last_user_pos = spans.last_user_message_pos();
|
||||
const int32_t n_before_user = slot.task->params.n_before_user;
|
||||
const bool n_before_user_known = n_before_user > 0;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) {
|
||||
@@ -3466,8 +3460,10 @@ private:
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// stop the prompt batch exactly before a user message
|
||||
if (spans.is_user_start(slot.prompt.n_tokens())) {
|
||||
// stop the prompt batch exactly before the latest user input, so a checkpoint
|
||||
// can be created after the previous messages
|
||||
if (n_before_user_known &&
|
||||
slot.prompt.n_tokens() == n_before_user) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -3496,13 +3492,8 @@ private:
|
||||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.size() - n_tokens_prev;
|
||||
|
||||
const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
|
||||
|
||||
const bool is_user_start = spans.is_user_start(n_tokens_start);
|
||||
const bool is_last_user_message = n_tokens_start == last_user_pos;
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
@@ -3517,9 +3508,8 @@ private:
|
||||
|
||||
slot.init_sampler();
|
||||
} else {
|
||||
// skip ordinary mid-prompt checkpoints, unless the batch starts a user
|
||||
// message or we are near the end of the prompt
|
||||
if (!is_user_start && !near_prompt_end) {
|
||||
// skip ordinary mid-prompt checkpoints
|
||||
if (!n_before_user_known && !near_prompt_end) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
@@ -3527,6 +3517,29 @@ private:
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
||||
|
||||
// checkpoints are created before the current batch is decoded, so
|
||||
// their token position is the batch start rather than the prompt end
|
||||
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
{
|
||||
const bool is_on_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start == n_before_user;
|
||||
|
||||
const bool is_after_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start > n_before_user;
|
||||
|
||||
const bool is_allowed =
|
||||
!n_before_user_known ||
|
||||
is_on_user ||
|
||||
(is_after_user && near_prompt_end);
|
||||
|
||||
if (do_checkpoint && !is_allowed) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
// nothing to checkpoint yet
|
||||
// TODO: is this check needed?
|
||||
if (do_checkpoint && pos_min < 0) {
|
||||
@@ -3536,8 +3549,8 @@ private:
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together, unless it's the last user message
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
@@ -4036,6 +4049,54 @@ void server_context::set_state_callback(server_state_callback_t callback) {
|
||||
});
|
||||
}
|
||||
|
||||
// compute the number of tokens before the last user message in the prompt
|
||||
static int32_t prompt_get_n_before_user(
|
||||
const json & message_spans,
|
||||
const std::string & prompt,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const llama_vocab * vocab,
|
||||
mtmd_context * mctx) {
|
||||
int32_t result = -1;
|
||||
int32_t byte_pos = -1;
|
||||
|
||||
for (const auto & span : message_spans) {
|
||||
const std::string role = json_value(span, "role", std::string());
|
||||
|
||||
if (role == "user") {
|
||||
byte_pos = json_value(span, "pos", -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (byte_pos >= 0) {
|
||||
GGML_ASSERT((size_t) byte_pos <= prompt.size());
|
||||
|
||||
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
|
||||
|
||||
const std::string marker = get_media_marker();
|
||||
size_t n_prefix_media = 0;
|
||||
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
|
||||
n_prefix_media++;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_prefix_media <= files.size());
|
||||
|
||||
if (mctx != nullptr && n_prefix_media > 0) {
|
||||
// TODO: this makes a copy - avoid it
|
||||
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
|
||||
|
||||
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
|
||||
} else {
|
||||
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
|
||||
}
|
||||
|
||||
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
|
||||
byte_pos, n_prefix_media, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// server_routes
|
||||
//
|
||||
@@ -4083,10 +4144,6 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
|
||||
// tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks
|
||||
|
||||
// message delimiters for checkpointing
|
||||
auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array()));
|
||||
delimiters.tokenize(ctx_server.vocab);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
@@ -4100,7 +4157,16 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
|
||||
task.params.message_spans = task.tokens.find_message_spans(delimiters);
|
||||
const auto message_spans = json_value(data, "message_spans", json::array());
|
||||
if (prompt.is_string() && message_spans.is_array()) {
|
||||
task.params.n_before_user =
|
||||
prompt_get_n_before_user(
|
||||
message_spans,
|
||||
prompt.get<std::string>(),
|
||||
files,
|
||||
ctx_server.vocab,
|
||||
ctx_server.mctx);
|
||||
}
|
||||
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ struct server_context_meta {
|
||||
};
|
||||
|
||||
enum server_state {
|
||||
SERVER_STATE_DOWNLOADING,
|
||||
// SERVER_STATE_DOWNLOADING,
|
||||
SERVER_STATE_LOADING,
|
||||
SERVER_STATE_READY,
|
||||
SERVER_STATE_SLEEPING,
|
||||
@@ -61,7 +61,6 @@ enum server_state {
|
||||
|
||||
static std::string server_state_to_str(server_state state) {
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING: return "downloading";
|
||||
case SERVER_STATE_LOADING: return "loading";
|
||||
case SERVER_STATE_READY: return "ready";
|
||||
case SERVER_STATE_SLEEPING: return "sleeping";
|
||||
@@ -70,7 +69,6 @@ static std::string server_state_to_str(server_state state) {
|
||||
}
|
||||
|
||||
static server_state server_state_from_str(const std::string & str) {
|
||||
if (str == "downloading") return SERVER_STATE_DOWNLOADING;
|
||||
if (str == "loading") return SERVER_STATE_LOADING;
|
||||
if (str == "ready") return SERVER_STATE_READY;
|
||||
if (str == "sleeping") return SERVER_STATE_SLEEPING;
|
||||
|
||||
+128
-228
@@ -64,17 +64,6 @@ struct server_subproc {
|
||||
return sproc.has_value() && subprocess_alive(&sproc.value());
|
||||
}
|
||||
|
||||
void request_exit() {
|
||||
if (sproc.has_value()) {
|
||||
FILE * stdin_file = subprocess_stdin(&sproc.value());
|
||||
if (stdin_file) {
|
||||
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
|
||||
fflush(stdin_file);
|
||||
}
|
||||
}
|
||||
stopped.store(true, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void terminate() {
|
||||
if (!sproc.has_value()) {
|
||||
return;
|
||||
@@ -334,7 +323,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo
|
||||
}
|
||||
|
||||
void server_models::load_models() {
|
||||
// Phase 1: load presets from all sources - pure I/O, no lock needed
|
||||
// Phase 1: load presets from all sources — pure I/O, no lock needed
|
||||
// 1. cached models
|
||||
common_presets cached_models = ctx_preset.load_from_cache();
|
||||
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
|
||||
@@ -387,7 +376,7 @@ void server_models::load_models() {
|
||||
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
|
||||
};
|
||||
|
||||
// Helpers that read `mapping` - must be called while holding the lock.
|
||||
// Helpers that read `mapping` — must be called while holding the lock.
|
||||
std::unordered_set<std::string> custom_names;
|
||||
for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
|
||||
auto join_set = [](const std::set<std::string> & s) {
|
||||
@@ -534,7 +523,7 @@ void server_models::load_models() {
|
||||
}
|
||||
}
|
||||
|
||||
// join outside the lock - monitoring thread calls update_status (needs lock)
|
||||
// join outside the lock — monitoring thread calls update_status (needs lock)
|
||||
lk.unlock();
|
||||
for (auto & th : threads_to_join) th.join();
|
||||
lk.lock();
|
||||
@@ -633,7 +622,7 @@ void server_models::load_models() {
|
||||
|
||||
apply_stop_timeout();
|
||||
|
||||
// clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
|
||||
// clear reload flag before unlocking for autoload — load() blocks on !is_reloading,
|
||||
// so clearing it here (while still locked) prevents a deadlock in the autoload calls below
|
||||
is_reloading = false;
|
||||
cv.notify_all();
|
||||
@@ -826,23 +815,17 @@ void server_models::unload_lru() {
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name) {
|
||||
load(name, load_options{});
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name, const load_options & opts) {
|
||||
if (!opts.custom_meta.has_value()) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
unload_lru();
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
unload_lru();
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// edge case: block until any in-progress reload has finished so we always load
|
||||
// against the freshest preset and a consistent mapping state
|
||||
cv.wait(lk, [this]() { return !is_reloading; });
|
||||
|
||||
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
|
||||
auto meta = mapping[name].meta;
|
||||
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
|
||||
SRV_INF("model %s is not ready\n", name.c_str());
|
||||
return;
|
||||
@@ -886,12 +869,6 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
std::vector<std::string> child_env = base_env; // copy
|
||||
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
|
||||
|
||||
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
|
||||
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
|
||||
}
|
||||
|
||||
SRV_INF("%s", "spawning server instance with args:\n");
|
||||
for (const auto & arg : child_args) {
|
||||
SRV_INF(" %s\n", arg.c_str());
|
||||
@@ -909,17 +886,13 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
if (result != 0) {
|
||||
throw std::runtime_error("failed to spawn server instance");
|
||||
}
|
||||
|
||||
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
|
||||
}
|
||||
|
||||
// start a thread to manage the child process
|
||||
// captured variables are guaranteed to be destroyed only after the thread is joined
|
||||
inst.th = std::thread([
|
||||
this, name,
|
||||
child_proc = inst.subproc,
|
||||
port = inst.meta.port,
|
||||
stop_timeout = inst.meta.stop_timeout,
|
||||
child_mode = opts.mode
|
||||
]() {
|
||||
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
|
||||
FILE * stdin_file = subprocess_stdin(&child_proc->get());
|
||||
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
|
||||
|
||||
@@ -952,7 +925,7 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
||||
});
|
||||
}
|
||||
// child crashed or finished on its own, skip graceful shutdown sequence
|
||||
// child crashed or finished on its own — skip graceful shutdown sequence
|
||||
if (child_proc->stopped.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
@@ -1000,14 +973,10 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
subprocess_destroy(&child_proc->get());
|
||||
|
||||
// update status and exit code
|
||||
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
// instance will be cleaned up on next load_models() call
|
||||
} else {
|
||||
this->update_status(name, {
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
exit_code
|
||||
});
|
||||
}
|
||||
this->update_status(name, {
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
exit_code
|
||||
});
|
||||
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
|
||||
});
|
||||
|
||||
@@ -1015,7 +984,7 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
{
|
||||
auto & old_instance = mapping[name];
|
||||
// old process should have exited already, but just in case, we clean it up here
|
||||
if (old_instance.subproc && old_instance.subproc->is_alive()) {
|
||||
if (old_instance.subproc->is_alive()) {
|
||||
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
|
||||
old_instance.subproc->terminate(); // force kill
|
||||
}
|
||||
@@ -1032,13 +1001,92 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
// callback for model downloading functionality
|
||||
struct server_models_download_res : public common_download_callback {
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
|
||||
std::function<bool()> should_stop;
|
||||
std::function<void(const common_download_progress & p)> on_progress;
|
||||
|
||||
bool is_ok = false;
|
||||
|
||||
bool run() {
|
||||
try {
|
||||
common_download_model(model, opts);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = model.get_name();
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_done(const common_download_progress &, bool ok) override {
|
||||
is_ok = ok;
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop();
|
||||
}
|
||||
};
|
||||
|
||||
void server_models::download(common_params_model && model, common_download_opts && opts) {
|
||||
std::string name = model.get_name();
|
||||
GGML_ASSERT(name == model.hf_repo);
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
if (mapping.find(name) != mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " already exists");
|
||||
}
|
||||
|
||||
instance_t inst;
|
||||
inst.meta.name = name;
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
inst.subproc = std::make_shared<server_subproc>();
|
||||
|
||||
auto dl = std::make_unique<server_models_download_res>();
|
||||
dl->model = model; // copy
|
||||
dl->opts = opts; // copy
|
||||
|
||||
dl->should_stop = [sp = inst.subproc]() {
|
||||
return sp->stopped.load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
dl->on_progress = [this, name](const common_download_progress & p) {
|
||||
update_download_progress(name, p, false);
|
||||
};
|
||||
|
||||
inst.th = std::thread([this, dl = std::move(dl)]() {
|
||||
dl->opts.callback = dl.get();
|
||||
bool ok = dl->run();
|
||||
auto model_name = dl->model.get_name();
|
||||
SRV_INF("download finished for model name=%s with status=%s\n",
|
||||
model_name.c_str(), ok ? "success" : "failure");
|
||||
update_download_progress(model_name, {}, true, ok);
|
||||
// need_reload is set inside update_download_progress under the mutex;
|
||||
// the next load_models() call will clean up this instance
|
||||
});
|
||||
|
||||
mapping[name] = std::move(inst);
|
||||
notify_sse("status_update", name, {
|
||||
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
|
||||
});
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::unload(const std::string & name) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->request_exit();
|
||||
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
|
||||
// for convenience, we wait the status change here
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
@@ -1150,65 +1198,37 @@ void server_models::update_download_progress(const std::string & name, const com
|
||||
}
|
||||
|
||||
bool server_models::remove(const std::string & name) {
|
||||
// do everything under one lock acquisition; avoid get_meta() /
|
||||
// unload() because they can trigger load_models() which erases
|
||||
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto meta = get_meta(name);
|
||||
|
||||
auto it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
if (!meta.has_value()) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
|
||||
}
|
||||
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
// cancel in-flight download
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->request_exit();
|
||||
} else if (it->second.meta.is_running()) {
|
||||
// stop running instance
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
||||
it->second.subproc->terminate();
|
||||
unload(name); // cancel download or stop running instance
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
|
||||
// longer acquires this mutex, so joining while holding it is safe
|
||||
if (mapping[name].th.joinable()) {
|
||||
mapping[name].th.join();
|
||||
}
|
||||
cv_stop.notify_all();
|
||||
}
|
||||
|
||||
// wait until the monitoring thread finishes
|
||||
wait(lk, name, [](const server_model_meta & meta) {
|
||||
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
|
||||
// re-find after wait - load_models() may have erased the entry during the wait
|
||||
it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
// load_models() already joined the thread and erased the entry;
|
||||
// we just need to clean up the cached files on disk
|
||||
lk.unlock();
|
||||
// remove the model from disk (hold lock to prevent concurrent load)
|
||||
bool ok = common_download_remove(name);
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
if (ok) {
|
||||
mapping.erase(name);
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
return ok;
|
||||
}
|
||||
|
||||
// join before erasing - thread no longer acquires this mutex
|
||||
if (it->second.th.joinable()) {
|
||||
it->second.th.join();
|
||||
}
|
||||
|
||||
// remove from disk (best-effort: cancelled downloads may have no cached files)
|
||||
bool ok = common_download_remove(name);
|
||||
mapping.erase(name);
|
||||
if (!ok) {
|
||||
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
||||
@@ -1223,9 +1243,7 @@ void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string &
|
||||
return predicate(it->second.meta);
|
||||
|
||||
}
|
||||
// model was removed from mapping by another code path (e.g. load_models()).
|
||||
// nothing left to wait for - tell the caller to proceed.
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1310,31 +1328,6 @@ void server_models::handle_child_state(const std::string & name, const std::stri
|
||||
}
|
||||
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING:
|
||||
{
|
||||
std::string result = json_value(payload, "result", std::string());
|
||||
std::string url = json_value(payload, "url", std::string());
|
||||
auto request_exit = [&]() {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return it->second.subproc->request_exit();
|
||||
}
|
||||
};
|
||||
if (result == "download_finished") {
|
||||
update_download_progress(name, {}, true, true);
|
||||
request_exit();
|
||||
} else if (result == "download_failed") {
|
||||
update_download_progress(name, {}, true, false);
|
||||
request_exit();
|
||||
} else if (!url.empty()) {
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
p.downloaded = json_value(payload, "downloaded", (size_t)0);
|
||||
p.total = json_value(payload, "total", (size_t)0);
|
||||
update_download_progress(name, p, false);
|
||||
}
|
||||
} break;
|
||||
case SERVER_STATE_LOADING:
|
||||
{
|
||||
update_status(name, {
|
||||
@@ -1373,90 +1366,6 @@ bool server_child::is_child() {
|
||||
return router_port != nullptr;
|
||||
}
|
||||
|
||||
server_child_mode server_child::get_mode() {
|
||||
const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
|
||||
std::string mode_str(mode ? mode : "");
|
||||
if (mode_str == "download") {
|
||||
return SERVER_CHILD_MODE_DOWNLOAD;
|
||||
} else {
|
||||
return SERVER_CHILD_MODE_NORMAL;
|
||||
}
|
||||
}
|
||||
|
||||
struct server_download_state : public common_download_callback {
|
||||
server_child * self;
|
||||
std::function<bool()> should_stop;
|
||||
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
|
||||
bool is_ok = false;
|
||||
|
||||
server_download_state(server_child * s) : self(s) {}
|
||||
|
||||
bool run(common_params & params) {
|
||||
try {
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = params.model.get_name();
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_progress(const common_download_progress & p) {
|
||||
json data = {
|
||||
{"url", p.url},
|
||||
{"downloaded", p.downloaded},
|
||||
{"total", p.total},
|
||||
};
|
||||
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
int64_t now = ggml_time_ms();
|
||||
// throttle progress updates to avoid flooding logs
|
||||
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
|
||||
on_progress(p);
|
||||
last_progress_time.store(now, std::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
void on_done(const common_download_progress & p, bool) override {
|
||||
on_progress(p);
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop ? should_stop() : false;
|
||||
}
|
||||
};
|
||||
|
||||
int server_child::run_download(common_params & params) {
|
||||
auto cancelled = std::make_shared<std::atomic<bool>>(false);
|
||||
|
||||
// monitor stdin for cancellation command from the router
|
||||
std::thread signal_thread = setup([cancelled](int) {
|
||||
cancelled->store(true, std::memory_order_relaxed);
|
||||
});
|
||||
|
||||
server_download_state dl(this);
|
||||
dl.should_stop = [cancelled]() {
|
||||
return cancelled->load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
bool ok = dl.run(params);
|
||||
|
||||
notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
|
||||
{"result", ok ? "download_finished" : "download_failed"},
|
||||
});
|
||||
|
||||
// router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
|
||||
if (signal_thread.joinable()) {
|
||||
signal_thread.join();
|
||||
}
|
||||
|
||||
SRV_INF("download completed %s\n", ok ? "successfully" : "with errors");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
|
||||
// setup thread for monitoring stdin
|
||||
return std::thread([shutdown_handler]() {
|
||||
@@ -1730,7 +1639,7 @@ void server_models_routes::init_routes() {
|
||||
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
if (!model->is_running()) {
|
||||
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
@@ -1771,9 +1680,8 @@ void server_models_routes::init_routes() {
|
||||
|
||||
model.hf_repo = name;
|
||||
opts.bearer_token = params.hf_token;
|
||||
// note: we only check main model, no need sidecar here
|
||||
opts.download_mmproj = false;
|
||||
opts.download_mtp = false;
|
||||
opts.download_mmproj = true;
|
||||
opts.download_mtp = true;
|
||||
|
||||
// first, only check if the model is valid and can be downloaded
|
||||
opts.skip_download = true;
|
||||
@@ -1794,21 +1702,10 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model validation failed, unable to download");
|
||||
}
|
||||
|
||||
// reject if model already exists
|
||||
if (models.has_model(name)) {
|
||||
throw std::invalid_argument("model '" + name + "' already exists");
|
||||
}
|
||||
|
||||
// then, proceed with the actual download
|
||||
opts.skip_download = false;
|
||||
SRV_INF("starting download for model '%s'\n", name.c_str());
|
||||
{
|
||||
server_models::load_options load_opts;
|
||||
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
|
||||
load_opts.custom_meta = server_model_meta{};
|
||||
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
|
||||
load_opts.custom_meta->name = name;
|
||||
models.load(name, load_opts);
|
||||
}
|
||||
models.download(std::move(model), std::move(opts));
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
@@ -1822,7 +1719,10 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model must be a non-empty string");
|
||||
}
|
||||
|
||||
models.remove(name); // throws on error
|
||||
bool ok = models.remove(name);
|
||||
if (!ok) {
|
||||
throw std::runtime_error("failed to remove model '" + name + "'");
|
||||
}
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
|
||||
@@ -40,11 +40,6 @@ enum server_model_source {
|
||||
SERVER_MODEL_SOURCE_CACHE,
|
||||
};
|
||||
|
||||
enum server_child_mode {
|
||||
SERVER_CHILD_MODE_NORMAL, // load the model and run normally
|
||||
SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit
|
||||
};
|
||||
|
||||
static std::string server_model_status_to_string(server_model_status status) {
|
||||
switch (status) {
|
||||
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
|
||||
@@ -110,6 +105,7 @@ private:
|
||||
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
|
||||
std::thread th;
|
||||
server_model_meta meta;
|
||||
FILE * stdin_file = nullptr;
|
||||
};
|
||||
|
||||
std::mutex mutex;
|
||||
@@ -165,19 +161,16 @@ public:
|
||||
// return a copy of all model metadata (thread-safe)
|
||||
std::vector<server_model_meta> get_all_meta();
|
||||
|
||||
struct load_options {
|
||||
server_child_mode mode = SERVER_CHILD_MODE_NORMAL;
|
||||
// used for spawning a downloading child process
|
||||
std::optional<server_model_meta> custom_meta = std::nullopt;
|
||||
};
|
||||
|
||||
// load and unload model instances
|
||||
// these functions are thread-safe
|
||||
void load(const std::string & name);
|
||||
void load(const std::string & name, const load_options & opts);
|
||||
void unload(const std::string & name);
|
||||
void unload_all();
|
||||
|
||||
// download a new model, progress is reported via SSE
|
||||
// to stop the download, call unload()
|
||||
void download(common_params_model && model, common_download_opts && opts);
|
||||
|
||||
struct update_status_args {
|
||||
server_model_status status;
|
||||
int exit_code = 0; // only valid if status == UNLOADED
|
||||
@@ -220,12 +213,9 @@ public:
|
||||
struct server_child {
|
||||
// serializes the notify_to_router writes
|
||||
std::mutex mtx_stdout;
|
||||
std::atomic<bool> is_finished_downloading = false; // set by run_download
|
||||
|
||||
// return true if the current process is a child server instance
|
||||
bool is_child();
|
||||
server_child_mode get_mode();
|
||||
int run_download(common_params & params);
|
||||
|
||||
// register the shutdown_handler to be called by the router
|
||||
// return the monitoring thread (to be joined by the caller)
|
||||
|
||||
@@ -591,11 +591,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() {
|
||||
|
||||
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
|
||||
output.push_back(json {
|
||||
{"id", "fc_" + tool_call.id},
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"call_id", "call_" + tool_call.id},
|
||||
{"call_id", "fc_" + tool_call.id},
|
||||
{"name", tool_call.name},
|
||||
});
|
||||
}
|
||||
@@ -691,11 +690,10 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() {
|
||||
|
||||
for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) {
|
||||
const json output_item = {
|
||||
{"id", "fc_" + tool_call.id},
|
||||
{"type", "function_call"},
|
||||
{"status", "completed"},
|
||||
{"arguments", tool_call.arguments},
|
||||
{"call_id", "call_" + tool_call.id},
|
||||
{"call_id", "fc_" + tool_call.id},
|
||||
{"name", tool_call.name}
|
||||
};
|
||||
server_sent_events.push_back(json {
|
||||
@@ -1279,9 +1277,8 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() {
|
||||
{"data", json {
|
||||
{"type", "response.output_item.added"},
|
||||
{"item", json {
|
||||
{"id", "fc_" + diff.tool_call_delta.id},
|
||||
{"arguments", ""},
|
||||
{"call_id", "call_" + diff.tool_call_delta.id},
|
||||
{"call_id", "fc_" + diff.tool_call_delta.id},
|
||||
{"name", diff.tool_call_delta.name},
|
||||
{"type", "function_call"},
|
||||
{"status", "in_progress"},
|
||||
|
||||
@@ -62,6 +62,9 @@ struct task_params {
|
||||
|
||||
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
|
||||
|
||||
// number of prompt tokens before the latest user message
|
||||
int32_t n_before_user = -1;
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
@@ -89,9 +92,6 @@ struct task_params {
|
||||
// per-request parameters for chat parsing
|
||||
common_chat_parser_params chat_parser_params;
|
||||
|
||||
// message spans for checkpointing
|
||||
common_chat_msg_spans message_spans;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
|
||||
+1
-12
@@ -134,7 +134,6 @@ int llama_server(int argc, char ** argv) {
|
||||
//
|
||||
|
||||
// register API routes
|
||||
server_child child; // only used in non-router mode
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
@@ -255,21 +254,11 @@ int llama_server(int argc, char ** argv) {
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Handle downloading model
|
||||
//
|
||||
|
||||
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
//
|
||||
|
||||
server_child child; // only used in non-router mode
|
||||
std::function<void()> clean_up;
|
||||
|
||||
if (is_router_server) {
|
||||
|
||||
@@ -257,25 +257,14 @@ def test_router_reload_models():
|
||||
|
||||
|
||||
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
|
||||
MODEL_DOWNLOAD_TIMEOUT = 30
|
||||
MODEL_DOWNLOAD_TIMEOUT = 300
|
||||
|
||||
|
||||
def _listen_sse(
|
||||
server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None
|
||||
):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set.
|
||||
|
||||
When `ready` is provided, it is set once the streaming response is open,
|
||||
i.e. the server has accepted the connection and registered us as a
|
||||
subscriber. Callers that trigger one-shot events (e.g. download_finished)
|
||||
must wait on `ready` before acting, otherwise the event can be broadcast
|
||||
before this client is subscribed and be lost.
|
||||
"""
|
||||
def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set."""
|
||||
url = f"http://{server.server_host}:{server.server_port}/models/sse"
|
||||
try:
|
||||
with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp:
|
||||
if ready is not None:
|
||||
ready.set()
|
||||
for line_bytes in resp.iter_lines():
|
||||
if stop.is_set():
|
||||
break
|
||||
@@ -305,17 +294,11 @@ def test_router_download_model():
|
||||
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
sse_thread = threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
)
|
||||
sse_thread.start()
|
||||
|
||||
# wait for the SSE client to be subscribed before triggering the download,
|
||||
# otherwise the one-shot download_finished event can be broadcast before
|
||||
# this client is registered and be lost
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
|
||||
# Trigger the download
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
@@ -345,17 +328,13 @@ def test_router_delete_model():
|
||||
|
||||
# Ensure the model exists (download it if needed)
|
||||
if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False):
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
).start()
|
||||
# subscribe before triggering the download so the one-shot
|
||||
# download_finished event is not lost (see test_router_download_model)
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
).start()
|
||||
finished = _wait_for_sse_event(
|
||||
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
|
||||
)
|
||||
|
||||
Vendored
-10
@@ -19,10 +19,6 @@ import type {
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
@@ -56,7 +52,6 @@ import type {
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
@@ -88,10 +83,6 @@ declare global {
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
@@ -129,7 +120,6 @@ declare global {
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
|
||||
+3
-12
@@ -10,7 +10,7 @@
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { copyToClipboard, deriveAgenticSections, modelLoadProgressText } from '$lib/utils';
|
||||
import { copyToClipboard, deriveAgenticSections } from '$lib/utils';
|
||||
import { AgenticSectionType } from '$lib/enums';
|
||||
import { REASONING_TAGS } from '$lib/constants/agentic';
|
||||
import { tick } from 'svelte';
|
||||
@@ -185,13 +185,6 @@
|
||||
let hasNoContent = $derived(!message?.content?.trim());
|
||||
let isActivelyProcessing = $derived(isCurrentlyLoading || isStreaming);
|
||||
|
||||
// during a router auto-load the message has no model yet, so target the selected one
|
||||
let loadTargetModel = $derived(message.model ?? modelsStore.selectedModelName);
|
||||
let modelLoadProgress = $derived(
|
||||
isRouter && loadTargetModel ? modelsStore.getLoadProgress(loadTargetModel) : null
|
||||
);
|
||||
let modelLoadingText = $derived(modelLoadProgressText(modelLoadProgress));
|
||||
|
||||
let showProcessingInfoTop = $derived(
|
||||
message?.role === MessageRole.ASSISTANT &&
|
||||
isActivelyProcessing &&
|
||||
@@ -227,8 +220,7 @@
|
||||
<div class="mt-6 w-full max-w-[48rem]" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
processingState.getPromptProgressText() ??
|
||||
{processingState.getPromptProgressText() ??
|
||||
processingState.getProcessingMessage() ??
|
||||
'Processing...'}
|
||||
</span>
|
||||
@@ -260,8 +252,7 @@
|
||||
<div class="mt-4 w-full max-w-[48rem]" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
processingState.getPromptProgressText() ??
|
||||
{processingState.getPromptProgressText() ??
|
||||
processingState.getProcessingMessage() ??
|
||||
'Processing...'}
|
||||
</span>
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
import type { ModelOption } from '$lib/types/models';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction, modelLoadProgressText } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
option: ModelOption;
|
||||
@@ -51,15 +50,11 @@
|
||||
(serverStatus === ServerModelStatus.LOADED || isSleeping) && !isOperationInProgress
|
||||
);
|
||||
let isLoading = $derived(serverStatus === ServerModelStatus.LOADING || isOperationInProgress);
|
||||
|
||||
let loadProgress = $derived(isLoading ? modelsStore.getLoadProgress(option.model) : null);
|
||||
let loadPercent = $derived(Math.round(modelLoadFraction(loadProgress) * 100));
|
||||
let loadTitle = $derived(modelLoadProgressText(loadProgress));
|
||||
</script>
|
||||
|
||||
<div
|
||||
class={[
|
||||
'group relative flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
|
||||
'group flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
|
||||
'cursor-pointer hover:bg-muted focus:bg-muted',
|
||||
(isSelected || isHighlighted) && 'bg-accent text-accent-foreground',
|
||||
!(isSelected || isHighlighted) && 'hover:bg-accent hover:text-accent-foreground',
|
||||
@@ -67,7 +62,6 @@
|
||||
]}
|
||||
role="option"
|
||||
aria-selected={isSelected || isHighlighted}
|
||||
title={loadTitle}
|
||||
tabindex="0"
|
||||
onclick={() => onSelect(option.id)}
|
||||
onmouseenter={onMouseEnter}
|
||||
@@ -194,15 +188,4 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div
|
||||
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {loadPercent}%"
|
||||
></div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
export const API_MODELS = {
|
||||
LIST: '/v1/models',
|
||||
LOAD: '/models/load',
|
||||
UNLOAD: '/models/unload',
|
||||
SSE: '/models/sse'
|
||||
UNLOAD: '/models/unload'
|
||||
};
|
||||
|
||||
// chat completion routes, the control route drives realtime inference (e.g. end reasoning)
|
||||
|
||||
@@ -37,8 +37,6 @@ export * from './mcp-form';
|
||||
export * from './mcp-resource';
|
||||
export * from './message-export';
|
||||
export * from './model-id';
|
||||
export * from './model-loading';
|
||||
export * from './sse';
|
||||
export * from './precision';
|
||||
export * from './processing-info';
|
||||
export * from './pwa';
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
/**
|
||||
* Labels shown while a model loads, keyed by the stage reported on /models/sse.
|
||||
*/
|
||||
export const MODEL_LOAD_STAGE_LABELS: Record<ApiModelLoadStage, string> = {
|
||||
text_model: 'Loading weights',
|
||||
spec_model: 'Loading draft',
|
||||
mmproj_model: 'Loading projector'
|
||||
};
|
||||
|
||||
/**
|
||||
* Share of the bar reserved for each load phase after text_model.
|
||||
* text_model fills the rest, so a plain model reaches 100% on its own.
|
||||
*/
|
||||
export const MODEL_LOAD_TAIL_SHARE = 0.1;
|
||||
@@ -1,16 +0,0 @@
|
||||
/**
|
||||
* Server-sent events wire format, shared by the chat stream and the
|
||||
* /models/sse status feed (text/event-stream).
|
||||
*/
|
||||
|
||||
// blank line between two events
|
||||
export const SSE_RECORD_SEPARATOR = '\n\n';
|
||||
|
||||
// line break inside an event
|
||||
export const SSE_LINE_SEPARATOR = '\n';
|
||||
|
||||
// data field prefix, the value follows after an optional space
|
||||
export const SSE_DATA_PREFIX = 'data:';
|
||||
|
||||
// end-of-stream marker on the chat completion stream
|
||||
export const SSE_DONE_MARKER = '[DONE]';
|
||||
@@ -54,7 +54,7 @@ export {
|
||||
|
||||
export { ModelModality } from './model.enums';
|
||||
|
||||
export { ServerRole, ServerModelStatus, ServerModelsSseEventType } from './server.enums';
|
||||
export { ServerRole, ServerModelStatus } from './server.enums';
|
||||
|
||||
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums';
|
||||
|
||||
|
||||
@@ -19,17 +19,3 @@ export enum ServerModelStatus {
|
||||
SLEEPING = 'sleeping',
|
||||
FAILED = 'failed'
|
||||
}
|
||||
|
||||
/**
|
||||
* /models/sse event type enum - discriminates the records broadcast on the
|
||||
* model status feed in ROUTER mode. Matches the event names emitted by
|
||||
* tools/server/server-models.cpp from the C++ server.
|
||||
*/
|
||||
export enum ServerModelsSseEventType {
|
||||
STATUS_CHANGE = 'status_change',
|
||||
MODEL_STATUS = 'model_status',
|
||||
STATUS_UPDATE = 'status_update',
|
||||
MODELS_RELOAD = 'models_reload',
|
||||
MODEL_REMOVE = 'model_remove',
|
||||
DOWNLOAD_PROGRESS = 'download_progress'
|
||||
}
|
||||
|
||||
@@ -10,10 +10,7 @@ import {
|
||||
SETTINGS_KEYS,
|
||||
API_CHAT,
|
||||
API_SLOTS,
|
||||
CONTROL_ACTION,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX,
|
||||
SSE_DONE_MARKER
|
||||
CONTROL_ACTION
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
AttachmentType,
|
||||
@@ -21,7 +18,8 @@ import {
|
||||
FileTypeAudio,
|
||||
MessageRole,
|
||||
MimeTypeAudio,
|
||||
ReasoningFormat
|
||||
ReasoningFormat,
|
||||
UrlProtocol
|
||||
} from '$lib/enums';
|
||||
import type {
|
||||
ApiChatMessageContentPart,
|
||||
@@ -644,15 +642,15 @@ export class ChatService {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
chunk += decoder.decode(value, { stream: true });
|
||||
const lines = chunk.split(SSE_LINE_SEPARATOR);
|
||||
const lines = chunk.split('\n');
|
||||
chunk = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
if (line.startsWith(SSE_DATA_PREFIX)) {
|
||||
const data = line.slice(SSE_DATA_PREFIX.length).trim();
|
||||
if (data === SSE_DONE_MARKER) {
|
||||
if (line.startsWith(UrlProtocol.DATA)) {
|
||||
const data = line.slice(6);
|
||||
if (data === '[DONE]') {
|
||||
streamFinished = true;
|
||||
|
||||
continue;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { base } from '$app/paths';
|
||||
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { ServerModelStatus, ServerModelsSseEventType, ModelModality } from '$lib/enums';
|
||||
import { ServerModelStatus, ModelModality } from '$lib/enums';
|
||||
import { ModelsService } from '$lib/services/models.service';
|
||||
import { PropsService } from '$lib/services/props.service';
|
||||
import { serverStore, isRouterMode } from '$lib/stores/server.svelte';
|
||||
@@ -9,15 +8,11 @@ import {
|
||||
detectThinkingSupport,
|
||||
detectThinkingSupportWithReason
|
||||
} from '$lib/utils/chat-template-thinking-detector';
|
||||
import { TTLCache, getAuthHeaders } from '$lib/utils';
|
||||
import { TTLCache } from '$lib/utils';
|
||||
import {
|
||||
MODEL_PROPS_CACHE_TTL_MS,
|
||||
MODEL_PROPS_CACHE_MAX_ENTRIES,
|
||||
FAVORITE_MODELS_LOCALSTORAGE_KEY,
|
||||
API_MODELS,
|
||||
SSE_RECORD_SEPARATOR,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX
|
||||
FAVORITE_MODELS_LOCALSTORAGE_KEY
|
||||
} from '$lib/constants';
|
||||
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
@@ -60,15 +55,6 @@ class ModelsStore {
|
||||
private modelUsage = $state<Map<string, SvelteSet<string>>>(new Map());
|
||||
private modelLoadingStates = new SvelteMap<string, boolean>();
|
||||
|
||||
// /models/sse feed state, the single source of truth for status and load progress
|
||||
private statusAbort: AbortController | null = null;
|
||||
private statusReaderActive = false;
|
||||
private loadProgress = new SvelteMap<string, ModelLoadProgress>();
|
||||
private statusWaiters = new Map<
|
||||
string,
|
||||
{ target: ServerModelStatus; resolve: () => void; reject: (e: Error) => void }
|
||||
>();
|
||||
|
||||
favoriteModelIds = $state<Set<string>>(this.loadFavoritesFromStorage());
|
||||
|
||||
/**
|
||||
@@ -545,8 +531,7 @@ class ModelsStore {
|
||||
* 1. Model from active conversation's last assistant response (if loaded)
|
||||
* 2. Model from active conversation's last assistant response (if not loaded)
|
||||
* 3. First loaded model (not from active conversation)
|
||||
* 4. A favorite model
|
||||
* 5. First available model
|
||||
* 4. First available model
|
||||
*/
|
||||
async ensureFirstModelSelected(): Promise<void> {
|
||||
if (this.selectedModelName) return;
|
||||
@@ -575,13 +560,6 @@ class ModelsStore {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try loading a favorite model
|
||||
const favorite = this.favoriteModelIds.values().next()?.value
|
||||
if (favorite) {
|
||||
await this.selectModelById(favorite);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to the first available model
|
||||
await this.selectModelById(availableModels[0].id);
|
||||
}
|
||||
@@ -648,218 +626,49 @@ class ModelsStore {
|
||||
*
|
||||
*/
|
||||
|
||||
// reconnect delay after the feed drops or the server is not ready yet
|
||||
private static readonly SSE_RECONNECT_MS = 1000;
|
||||
/**
|
||||
* WORKAROUND: Polling for model status after load/unload operations.
|
||||
*
|
||||
* Currently, `/models/load` and `/models/unload` return success before
|
||||
* the operation actually completes on the server.
|
||||
*
|
||||
* TODO: Remove polling once llama-server properly waits for the operation
|
||||
* to complete before returning success.
|
||||
*/
|
||||
|
||||
private static readonly STATUS_POLL_INTERVAL = 500;
|
||||
|
||||
/**
|
||||
* Open the /models/sse feed and keep it live with auto reconnect.
|
||||
* Idempotent and router mode only. The feed drives status and progress,
|
||||
* so it replaces any post-operation polling.
|
||||
* Poll for expected model status after load/unload operation.
|
||||
* Keeps polling until the model reaches the expected status or fails.
|
||||
*/
|
||||
subscribeStatus(): void {
|
||||
if (this.statusReaderActive) return;
|
||||
if (!isRouterMode()) return;
|
||||
private async pollForModelStatus(
|
||||
modelId: string,
|
||||
expectedStatus: ServerModelStatus
|
||||
): Promise<void> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
await this.fetchRouterModels();
|
||||
|
||||
this.statusReaderActive = true;
|
||||
this.statusAbort = new AbortController();
|
||||
void this.runStatusReader(this.statusAbort.signal);
|
||||
}
|
||||
const currentStatus = this.getModelStatus(modelId);
|
||||
if (currentStatus === expectedStatus) return;
|
||||
|
||||
/**
|
||||
* Close the /models/sse feed and drop transient progress.
|
||||
*/
|
||||
unsubscribeStatus(): void {
|
||||
this.statusReaderActive = false;
|
||||
this.statusAbort?.abort();
|
||||
this.statusAbort = null;
|
||||
this.loadProgress.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Current load progress for a model, or null when not loading.
|
||||
*/
|
||||
getLoadProgress(modelId: string): ModelLoadProgress | null {
|
||||
return this.loadProgress.get(modelId) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read the feed and reconnect until unsubscribed. Splits the byte stream
|
||||
* into SSE records on the blank line boundary.
|
||||
*/
|
||||
private async runStatusReader(signal: AbortSignal): Promise<void> {
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
while (!signal.aborted) {
|
||||
try {
|
||||
const response = await fetch(`${base}${API_MODELS.SSE}`, {
|
||||
headers: getAuthHeaders(),
|
||||
signal
|
||||
});
|
||||
|
||||
if (response.ok && response.body) {
|
||||
const reader = response.body.getReader();
|
||||
let buffer = '';
|
||||
|
||||
while (!signal.aborted) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
|
||||
while (boundary !== -1) {
|
||||
this.handleStatusRecord(buffer.slice(0, boundary));
|
||||
buffer = buffer.slice(boundary + SSE_RECORD_SEPARATOR.length);
|
||||
boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// network drop or abort falls through to the reconnect delay
|
||||
if (currentStatus === ServerModelStatus.FAILED) {
|
||||
throw new Error(
|
||||
`Model failed to ${expectedStatus === ServerModelStatus.LOADED ? 'load' : 'unload'}`
|
||||
);
|
||||
}
|
||||
|
||||
if (signal.aborted) return;
|
||||
if (
|
||||
expectedStatus === ServerModelStatus.LOADED &&
|
||||
currentStatus === ServerModelStatus.UNLOADED &&
|
||||
attempt > 2
|
||||
) {
|
||||
throw new Error('Model was unloaded unexpectedly during loading');
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, ModelsStore.SSE_RECONNECT_MS));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse one SSE record. The payload rides in the data lines as a JSON
|
||||
* envelope that carries its own model, event and data fields.
|
||||
*/
|
||||
private handleStatusRecord(record: string): void {
|
||||
const payload = record
|
||||
.split(SSE_LINE_SEPARATOR)
|
||||
.filter((line) => line.startsWith(SSE_DATA_PREFIX))
|
||||
.map((line) => line.slice(SSE_DATA_PREFIX.length).trim())
|
||||
.join(SSE_LINE_SEPARATOR);
|
||||
|
||||
if (payload.length === 0) return;
|
||||
|
||||
let envelope: ApiModelsSseEvent;
|
||||
try {
|
||||
envelope = JSON.parse(payload);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
this.applyStatusEvent(envelope);
|
||||
}
|
||||
|
||||
/**
|
||||
* Route one feed record by event kind. Only the status_* events carry a
|
||||
* status payload, models_reload triggers a list refresh, model_remove drops
|
||||
* the row, download_* belong to the download surface, not here.
|
||||
*/
|
||||
private applyStatusEvent(event: ApiModelsSseEvent): void {
|
||||
switch (event.event) {
|
||||
case ServerModelsSseEventType.STATUS_CHANGE:
|
||||
case ServerModelsSseEventType.MODEL_STATUS:
|
||||
case ServerModelsSseEventType.STATUS_UPDATE:
|
||||
this.applyModelStatus(event);
|
||||
break;
|
||||
case ServerModelsSseEventType.MODELS_RELOAD:
|
||||
void this.fetchRouterModels();
|
||||
break;
|
||||
case ServerModelsSseEventType.MODEL_REMOVE:
|
||||
this.removeRouterModel(event.model);
|
||||
break;
|
||||
case ServerModelsSseEventType.DOWNLOAD_PROGRESS:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply a status envelope: update the model row, track or clear progress,
|
||||
* settle any pending load or unload awaiter.
|
||||
*/
|
||||
private applyModelStatus(event: ApiModelsSseEvent): void {
|
||||
const model = event.model;
|
||||
const data = event.data;
|
||||
if (!model || !data?.status) return;
|
||||
|
||||
const status = data.status;
|
||||
|
||||
this.setRouterModelStatus(model, status);
|
||||
|
||||
if (status === ServerModelStatus.LOADING) {
|
||||
if (data.progress) this.loadProgress.set(model, data.progress);
|
||||
} else {
|
||||
this.loadProgress.delete(model);
|
||||
}
|
||||
|
||||
if (status === ServerModelStatus.LOADED) {
|
||||
void this.updateModelModalities(model);
|
||||
}
|
||||
|
||||
const failed =
|
||||
status === ServerModelStatus.FAILED ||
|
||||
(status === ServerModelStatus.UNLOADED && (data.exit_code ?? 0) !== 0);
|
||||
|
||||
if (failed) {
|
||||
this.rejectStatus(model, new Error(`Model failed: ${this.toDisplayName(model)}`));
|
||||
return;
|
||||
}
|
||||
|
||||
this.settleStatus(model, status);
|
||||
}
|
||||
|
||||
/**
|
||||
* Drop a model row reported gone by the feed and settle its awaiters.
|
||||
*/
|
||||
private removeRouterModel(modelId: string): void {
|
||||
if (this.routerModels.findIndex((m) => m.id === modelId) === -1) return;
|
||||
|
||||
this.routerModels = this.routerModels.filter((m) => m.id !== modelId);
|
||||
this.loadProgress.delete(modelId);
|
||||
this.rejectStatus(modelId, new Error(`Model removed: ${this.toDisplayName(modelId)}`));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update one model row status in place, reassigning to trigger reactivity.
|
||||
*/
|
||||
private setRouterModelStatus(modelId: string, status: ServerModelStatus): void {
|
||||
const idx = this.routerModels.findIndex((m) => m.id === modelId);
|
||||
if (idx === -1) return;
|
||||
|
||||
const current = this.routerModels[idx];
|
||||
if (current.status.value === status) return;
|
||||
|
||||
const next = [...this.routerModels];
|
||||
next[idx] = { ...current, status: { ...current.status, value: status } };
|
||||
this.routerModels = next;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register an awaiter that resolves when the feed reports target status.
|
||||
* One operation runs per model at a time, so one awaiter per model is kept.
|
||||
*/
|
||||
private waitForStatus(modelId: string, target: ServerModelStatus): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.statusWaiters.set(modelId, { target, resolve, reject });
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve and drop the awaiter when the model reaches its target status.
|
||||
*/
|
||||
private settleStatus(modelId: string, status: ServerModelStatus): void {
|
||||
const waiter = this.statusWaiters.get(modelId);
|
||||
if (waiter && waiter.target === status) {
|
||||
this.statusWaiters.delete(modelId);
|
||||
waiter.resolve();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reject and drop the awaiter for a model.
|
||||
*/
|
||||
private rejectStatus(modelId: string, error: Error): void {
|
||||
const waiter = this.statusWaiters.get(modelId);
|
||||
if (waiter) {
|
||||
this.statusWaiters.delete(modelId);
|
||||
waiter.reject(error);
|
||||
attempt++;
|
||||
await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -870,18 +679,12 @@ class ModelsStore {
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
this.error = null;
|
||||
|
||||
// the feed drives completion, so it must be live before the request
|
||||
this.subscribeStatus();
|
||||
|
||||
const reachedLoaded = this.waitForStatus(modelId, ServerModelStatus.LOADED);
|
||||
reachedLoaded.catch(() => {});
|
||||
|
||||
try {
|
||||
await ModelsService.load(modelId);
|
||||
await reachedLoaded;
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.LOADED);
|
||||
await this.updateModelModalities(modelId);
|
||||
toast.success(`Model loaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
this.rejectStatus(modelId, error instanceof Error ? error : new Error('load failed'));
|
||||
this.error = error instanceof Error ? error.message : 'Failed to load model';
|
||||
toast.error(`Failed to load model: ${this.toDisplayName(modelId)}`);
|
||||
throw error;
|
||||
@@ -897,17 +700,11 @@ class ModelsStore {
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
this.error = null;
|
||||
|
||||
this.subscribeStatus();
|
||||
|
||||
const reachedUnloaded = this.waitForStatus(modelId, ServerModelStatus.UNLOADED);
|
||||
reachedUnloaded.catch(() => {});
|
||||
|
||||
try {
|
||||
await ModelsService.unload(modelId);
|
||||
await reachedUnloaded;
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED);
|
||||
toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
this.rejectStatus(modelId, error instanceof Error ? error : new Error('unload failed'));
|
||||
this.error = error instanceof Error ? error.message : 'Failed to unload model';
|
||||
toast.error(`Failed to unload model: ${this.toDisplayName(modelId)}`);
|
||||
throw error;
|
||||
@@ -986,9 +783,6 @@ class ModelsStore {
|
||||
}
|
||||
|
||||
clear(): void {
|
||||
this.unsubscribeStatus();
|
||||
this.statusWaiters.forEach((waiter) => waiter.reject(new Error('Models store cleared')));
|
||||
this.statusWaiters.clear();
|
||||
this.models = [];
|
||||
this.routerModels = [];
|
||||
this.loading = false;
|
||||
|
||||
Vendored
+1
-47
@@ -1,10 +1,4 @@
|
||||
import type {
|
||||
ContentPartType,
|
||||
FileTypeAudio,
|
||||
ServerModelStatus,
|
||||
ServerModelsSseEventType,
|
||||
ServerRole
|
||||
} from '$lib/enums';
|
||||
import type { ContentPartType, FileTypeAudio, ServerModelStatus, ServerRole } from '$lib/enums';
|
||||
import type { ChatMessagePromptProgress, ChatRole } from './chat';
|
||||
|
||||
export type AudioInputFormat = FileTypeAudio.WAV | FileTypeAudio.MP3;
|
||||
@@ -102,46 +96,6 @@ export interface ApiModelDataEntry {
|
||||
meta?: Record<string, unknown> | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load stage reported by the /models/sse feed, in load order.
|
||||
*/
|
||||
export type ApiModelLoadStage = 'text_model' | 'spec_model' | 'mmproj_model';
|
||||
|
||||
/**
|
||||
* Load progress snapshot: the full ordered stage plan, the active stage,
|
||||
* and its fractional value (0.0 -> 1.0).
|
||||
*/
|
||||
export interface ApiModelsSseProgress {
|
||||
stages: ApiModelLoadStage[];
|
||||
current: ApiModelLoadStage;
|
||||
value: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Status payload carried by a /models/sse envelope.
|
||||
* exit_code appears on unload.
|
||||
*/
|
||||
export interface ApiModelsSseData {
|
||||
status: ServerModelStatus;
|
||||
progress?: ApiModelsSseProgress;
|
||||
exit_code?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Event kind multiplexed on the /models/sse feed.
|
||||
* Only the status_* events carry a status payload, models_reload signals a
|
||||
* full list refresh, model_remove drops a row, download_* drive download UI.
|
||||
*/
|
||||
/**
|
||||
* One /models/sse record. event discriminates the kind, model names the
|
||||
* target instance, data carries the status payload when present.
|
||||
*/
|
||||
export interface ApiModelsSseEvent {
|
||||
model: string;
|
||||
event: ServerModelsSseEventType;
|
||||
data: ApiModelsSseData;
|
||||
}
|
||||
|
||||
export interface ApiModelDetails {
|
||||
name: string;
|
||||
model: string;
|
||||
|
||||
@@ -11,10 +11,6 @@ export type {
|
||||
ApiChatMessageData,
|
||||
ApiModelStatus,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelDetails,
|
||||
ApiModelListResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
@@ -74,12 +70,7 @@ export type {
|
||||
} from './database';
|
||||
|
||||
// Model types
|
||||
export type {
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
ModalityCapabilities
|
||||
} from './models';
|
||||
export type { ModelModalities, ModelOption, ModalityCapabilities } from './models';
|
||||
|
||||
// Settings types
|
||||
export type {
|
||||
|
||||
Vendored
+1
-12
@@ -1,4 +1,4 @@
|
||||
import type { ApiModelDataEntry, ApiModelDetails, ApiModelLoadStage } from '$lib/types/api';
|
||||
import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api';
|
||||
|
||||
export interface ModelModalities {
|
||||
vision: boolean;
|
||||
@@ -20,17 +20,6 @@ export interface ModelOption {
|
||||
tags?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Ephemeral UI-only load progress for one model instance.
|
||||
* Lives only while a load runs, driven by the /models/sse feed.
|
||||
* stage is absent until the feed reports its first stage.
|
||||
*/
|
||||
export interface ModelLoadProgress {
|
||||
stages: ApiModelLoadStage[];
|
||||
current: ApiModelLoadStage;
|
||||
value: number;
|
||||
}
|
||||
|
||||
export interface ParsedModelId {
|
||||
raw: string;
|
||||
orgName: string | null;
|
||||
|
||||
@@ -44,9 +44,6 @@ export { buildProxiedUrl, buildProxiedHeaders } from './cors-proxy';
|
||||
// URL utilities
|
||||
export { extractRootDomain, sanitizeExternalUrl } from './url';
|
||||
|
||||
// Progress helpers
|
||||
export { modelLoadFraction, modelLoadProgressText } from './progress';
|
||||
|
||||
// Conversation utilities
|
||||
export { createMessageCountMap, getMessageCount } from './conversation-utils';
|
||||
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
/**
|
||||
* Model load progress helpers for the /models/sse surfaces
|
||||
* (selector row and chat message).
|
||||
*/
|
||||
|
||||
import { MODEL_LOAD_STAGE_LABELS, MODEL_LOAD_TAIL_SHARE } from '$lib/constants';
|
||||
|
||||
/**
|
||||
* Human label for a model load stage.
|
||||
*/
|
||||
export function modelLoadStageLabel(stage: ApiModelLoadStage): string {
|
||||
return MODEL_LOAD_STAGE_LABELS[stage];
|
||||
}
|
||||
|
||||
/**
|
||||
* Overall load fraction (0.0 -> 1.0) across the declared stage plan.
|
||||
* text_model fills [0, 1 - tail], each later phase owns one tail slice.
|
||||
*/
|
||||
export function modelLoadFraction(progress: ModelLoadProgress | null): number {
|
||||
if (!progress) return 0;
|
||||
|
||||
const { stages, current, value } = progress;
|
||||
const tailCount = Math.max(stages.length - 1, 0);
|
||||
const textCeiling = 1 - tailCount * MODEL_LOAD_TAIL_SHARE;
|
||||
const idx = stages.indexOf(current);
|
||||
|
||||
if (idx <= 0) {
|
||||
return value * textCeiling;
|
||||
}
|
||||
|
||||
return textCeiling + (idx - 1 + value) * MODEL_LOAD_TAIL_SHARE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Single line describing load progress: active stage label and overall percent.
|
||||
* Returns null when there is no progress to show.
|
||||
*/
|
||||
export function modelLoadProgressText(progress: ModelLoadProgress | null): string | null {
|
||||
if (!progress) return null;
|
||||
|
||||
const label = modelLoadStageLabel(progress.current);
|
||||
return `${label} ${Math.round(modelLoadFraction(progress) * 100)}%`;
|
||||
}
|
||||
@@ -230,20 +230,6 @@
|
||||
}
|
||||
});
|
||||
|
||||
// Live model status and load progress via the /models/sse feed (router mode)
|
||||
$effect(() => {
|
||||
if (!browser) return;
|
||||
if (!isRouterMode()) return;
|
||||
|
||||
untrack(() => {
|
||||
modelsStore.subscribeStatus();
|
||||
});
|
||||
|
||||
return () => {
|
||||
modelsStore.unsubscribeStatus();
|
||||
};
|
||||
});
|
||||
|
||||
// Background MCP server health checks on app load
|
||||
// Fetch enabled servers from settings and run health checks in background
|
||||
$effect(() => {
|
||||
|
||||
Reference in New Issue
Block a user