Compare commits

..

18 Commits

Author SHA1 Message Date
Georgi Gerganov abd4d0bc4f speculative : update default params (#11954)
* speculative : update default params

* speculative : do not discard the last drafted token
2025-02-19 13:29:42 +02:00
Daniel Bevenius 9626d9351a llama : fix indentation in llama-grammar [no ci] (#11943)
This commit adjusts the indentation for the functions `parse_sequence`
and `parse_rule` in src/llama-grammar.cpp.

The motivation is consistency and improve readability.
2025-02-19 06:16:23 +01:00
igardev b58934c183 server : (webui) Enable communication with parent html (if webui is in iframe) (#11940)
* Webui: Enable communication with parent html (if webui is in iframe):
- Listens for "setText" command from parent with "text" and "context" fields. "text" is set in inputMsg, "context" is used as hidden context on the following requests to the llama.cpp server
- On pressing na Escape button sends command "escapePressed" to the parent

Example handling from the parent html side:
- Send command "setText" from parent html to webui in iframe:
const iframe = document.getElementById('askAiIframe');
if (iframe) {
	iframe.contentWindow.postMessage({ command: 'setText', text: text, context: context }, '*');
}

- Listen for Escape key from webui on parent html:
// Listen for escape key event in the iframe
window.addEventListener('keydown', (event) => {
	if (event.key === 'Escape') {
		// Process case when Escape is pressed inside webui
	}
});

* Move the extraContext from storage to app.context.

* Fix formatting.

* add Message.extra

* format + build

* MessageExtraContext

* build

* fix display

* rm console.log

---------

Co-authored-by: igardev <ivailo.gardev@akros.ch>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-02-18 23:01:44 +01:00
Olivier Chafik 63e489c025 tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)
* tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type

* addressed clang-tidy lints in [test-]chat.*

* rm minja deps from util & common & move it to common/minja/

* add name & tool_call_id to common_chat_msg

* add common_chat_tool

* added json <-> tools, msgs conversions to chat.h

* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)

* fix deepseek r1 slow test (no longer <think> opening w/ new template)

* allow empty tools w/ auto + grammar

* fix & test server grammar & json_schema params w/ & w/o --jinja
2025-02-18 18:03:23 +00:00
Xuan-Son Nguyen 63ac128563 server : add TEI API format for /rerank endpoint (#11942)
* server : add TEI API format for /rerank endpoint

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix

* also gitignore examples/server/*.gz.hpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-02-18 14:21:41 +01:00
MoonRide303 5137da7b8c scripts: corrected encoding when getting chat template (#11866) (#11907)
Signed-off-by: MoonRide303 <moonride303@gmail.com>
2025-02-18 10:30:16 +01:00
xiaobing318 09aaf4f1f5 docs : Fix duplicated file extension in test command (#11935)
This commit fixes an issue in the llama.cpp project where the command for testing the llama-server object contained a duplicated file extension. The original command was:
./tests.sh unit/test_chat_completion.py.py -v -x
It has been corrected to:
./tests.sh unit/test_chat_completion.py -v -x
This change ensures that the test script correctly locates and executes the intended test file, preventing test failures due to an incorrect file name.
2025-02-18 10:12:49 +01:00
Johannes Gäßler 73e2ed3ce3 CUDA: use async data loading for FlashAttention (#11894)
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-02-17 14:03:24 +01:00
Eve f7b1116af1 update release requirements (#11897) 2025-02-17 12:20:23 +01:00
Antoine Viallon c4d29baf32 server : fix divide-by-zero in metrics reporting (#11915) 2025-02-17 11:25:12 +01:00
Rémy O 2eea03d86a vulkan: implement several ops relevant for ggml_opt (#11769)
* vulkan: support memset_tensor

* vulkan: support GGML_OP_SUM

* vulkan: implement GGML_OP_ARGMAX

* vulkan: implement GGML_OP_SUB

* vulkan: implement GGML_OP_COUNT_EQUAL

* vulkan: implement GGML_OP_OPT_STEP_ADAMW

* vulkan: fix check_results RWKV_WKV6 crash and memory leaks

* vulkan: implement GGML_OP_REPEAT_BACK

* tests: remove invalid test-backend-ops REPEAT_BACK tests

* vulkan: fix COUNT_EQUAL memset using a fillBuffer command
2025-02-17 07:55:57 +01:00
Xuan-Son Nguyen 0f2bbe6564 server : bump httplib to 0.19.0 (#11908) 2025-02-16 17:11:22 +00:00
standby24x7 fe163d5bf3 common : Fix a typo in help (#11899)
This patch fixes a typo in command help.
prefx -> prefix

Signed-off-by: Masanari Iida <standby24x7@gmail.com>
2025-02-16 10:51:13 +01:00
Xuan-Son Nguyen 818a340ea8 ci : fix (again) arm64 build fails (#11895)
* docker : attempt fixing arm64 build on ci

* qemu v7.0.0-28
2025-02-16 10:36:39 +01:00
Jeff Bolz bf42a23d0a vulkan: support multi/vision rope, and noncontiguous rope (#11902) 2025-02-16 08:52:23 +01:00
Hale Chan c2ea16f260 metal : fix the crash caused by the lack of residency set support on Intel Macs. (#11904) 2025-02-16 08:50:26 +02:00
Johannes Gäßler 6dde178248 scripts: fix compare-llama-bench commit hash logic (#11891) 2025-02-15 20:23:22 +01:00
708-145 fc10c38ded examples: fix typo in imatrix/README.md (#11884)
* simple typo fixed

* Update examples/imatrix/README.md

---------

Co-authored-by: Tobias Bergmann <tobias.bergmann@gmx.de>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-02-15 21:03:30 +02:00
57 changed files with 3743 additions and 2525 deletions
+4
View File
@@ -374,6 +374,8 @@ jobs:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
@@ -1373,8 +1375,10 @@ jobs:
needs:
- ubuntu-cpu-cmake
- ubuntu-22-cmake-vulkan
- windows-latest-cmake
- windows-2019-cmake-cuda
- windows-latest-cmake-sycl
- windows-latest-cmake-hip-release
- macOS-latest-cmake-arm64
- macOS-latest-cmake-x64
+2
View File
@@ -51,6 +51,8 @@ jobs:
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
with:
image: tonistiigi/binfmt:qemu-v7.0.0-28
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
+1
View File
@@ -98,6 +98,7 @@ examples/server/*.css.hpp
examples/server/*.html.hpp
examples/server/*.js.hpp
examples/server/*.mjs.hpp
examples/server/*.gz.hpp
!build_64.sh
!examples/*.bat
!examples/*/*.kts
+1 -1
View File
@@ -1364,7 +1364,7 @@ llama-server: \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat.cpp \
common/chat.hpp \
common/chat.h \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \
+3 -3
View File
@@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
arg.h
base64.hpp
chat.cpp
chat.hpp
chat-template.hpp
chat.h
common.cpp
common.h
console.cpp
@@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
llguidance.cpp
log.cpp
log.h
minja.hpp
minja/chat-template.hpp
minja/minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp
+2 -1
View File
@@ -2,6 +2,7 @@
#include "log.h"
#include "sampling.h"
#include "chat.h"
#include <algorithm>
#include <climits>
@@ -2247,7 +2248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_LOG_VERBOSITY"));
add_opt(common_arg(
{"--log-prefix"},
"Enable prefx in log messages",
"Enable prefix in log messages",
[](common_params &) {
common_log_set_prefix(common_log_main(), true);
}
+623 -107
View File
@@ -1,8 +1,433 @@
#include "chat.hpp"
#include "chat-template.hpp"
#include "chat.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "minja.hpp"
#include "minja/chat-template.hpp"
#include "minja/minja.hpp"
#include <optional>
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
struct templates_params {
json messages;
json tools;
common_chat_tool_choice tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
if (tool_choice == "auto") {
return COMMON_CHAT_TOOL_CHOICE_AUTO;
}
if (tool_choice == "none") {
return COMMON_CHAT_TOOL_CHOICE_NONE;
}
if (tool_choice == "required") {
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
}
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
std::vector<common_chat_msg> msgs;
try {
if (!messages.is_array()) {
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
}
for (const auto & message : messages) {
if (!message.is_object()) {
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
}
common_chat_msg msg;
if (!message.contains("role")) {
throw std::runtime_error("Missing 'role' in message: " + message.dump());
}
msg.role = message.at("role");
if (message.contains("content")) {
const auto & content = message.at("content");
if (content.is_string()) {
msg.content = content;
} else if (content.is_array()) {
for (const auto & part : content) {
if (!part.contains("type")) {
throw std::runtime_error("Missing content part type: " + part.dump());
}
const auto & type = part.at("type");
if (type != "text") {
throw std::runtime_error("Unsupported content part type: " + type.dump());
}
common_chat_msg_content_part msg_part;
msg_part.type = type;
msg_part.text = part.at("text");
msg.content_parts.push_back(msg_part);
}
} else if (!content.is_null()) {
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
if (message.contains("reasoning_content")) {
msg.reasoning_content = message.at("reasoning_content");
}
if (message.contains("name")) {
msg.tool_name = message.at("name");
}
if (message.contains("tool_call_id")) {
msg.tool_call_id = message.at("tool_call_id");
}
if (message.contains("tool_calls")) {
for (const auto & tool_call : message.at("tool_calls")) {
common_chat_tool_call tc;
if (!tool_call.contains("type")) {
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
}
const auto & type = tool_call.at("type");
if (type != "function") {
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
}
if (!tool_call.contains("function")) {
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
}
const auto & fc = tool_call.at("function");
if (!fc.contains("name")) {
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
}
tc.name = fc.at("name");
tc.arguments = fc.at("arguments");
if (tool_call.contains("id")) {
tc.id = tool_call.at("id");
}
msg.tool_calls.push_back(tc);
}
}
msgs.push_back(msg);
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
}
return msgs;
}
template <>
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
json messages = json::array();
for (const auto & msg : msgs) {
if (!msg.content.empty() && !msg.content_parts.empty()) {
throw std::runtime_error("Cannot specify both content and content_parts");
}
json jmsg {
{"role", msg.role},
};
if (!msg.content.empty()) {
jmsg["content"] = msg.content;
} else if (!msg.content_parts.empty()) {
if (concat_typed_text) {
std::string text;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
continue;
}
if (!text.empty()) {
text += '\n';
}
text += part.text;
}
jmsg["content"] = text;
} else {
auto & parts = jmsg["content"] = json::array();
for (const auto & part : msg.content_parts) {
parts.push_back({
{"type", part.type},
{"text", part.text},
});
}
}
} else {
jmsg["content"] = json(); // null
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
}
if (!msg.tool_call_id.empty()) {
jmsg["tool_call_id"] = msg.tool_call_id;
}
if (!msg.tool_calls.empty()) {
auto & tool_calls = jmsg["tool_calls"] = json::array();
for (const auto & tool_call : msg.tool_calls) {
json tc {
{"type", "function"},
{"function", {
{"name", tool_call.name},
{"arguments", tool_call.arguments},
}},
};
if (!tool_call.id.empty()) {
tc["id"] = tool_call.id;
}
tool_calls.push_back(tc);
}
}
messages.push_back(jmsg);
}
return messages;
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
return common_chat_msgs_parse_oaicompat(json::parse(messages));
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
std::vector<common_chat_tool> result;
try {
if (!tools.is_null()) {
if (!tools.is_array()) {
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
}
for (const auto & tool : tools) {
if (!tool.contains("type")) {
throw std::runtime_error("Missing tool type: " + tool.dump());
}
const auto & type = tool.at("type");
if (!type.is_string() || type != "function") {
throw std::runtime_error("Unsupported tool type: " + tool.dump());
}
if (!tool.contains("function")) {
throw std::runtime_error("Missing tool function: " + tool.dump());
}
const auto & function = tool.at("function");
result.push_back({
/* .name = */ function.at("name"),
/* .description = */ function.at("description"),
/* .parameters = */ function.at("parameters").dump(),
});
}
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
}
return result;
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
return common_chat_tools_parse_oaicompat(json::parse(tools));
}
template <>
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
if (tools.empty()) {
return json();
}
auto result = json::array();
for (const auto & tool : tools) {
result.push_back({
{"type", "function"},
{"function", {
{"name", tool.name},
{"description", tool.description},
{"parameters", json::parse(tool.parameters)},
}},
});
}
return result;
}
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
common_chat_msg msg;
msg.role = "user";
msg.content = "test";
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
common_chat_templates_inputs inputs;
inputs.messages = {msg};
common_chat_templates_apply(tmpls.get(), inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
std::string fmt_past_msg;
if (!past_msg.empty()) {
inputs.messages = past_msg;
inputs.add_generation_prompt = false;
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
}
std::ostringstream ss;
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
inputs.messages.push_back(new_msg);
inputs.add_generation_prompt = add_ass;
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
auto add_simple_msg = [&](auto role, auto content) {
common_chat_msg msg;
msg.role = role;
msg.content = content;
inputs.messages.push_back(msg);
};
add_simple_msg("system", "You are a helpful assistant");
add_simple_msg("user", "Hello");
add_simple_msg("assistant", "Hi there");
add_simple_msg("user", "How are you?");
return common_chat_templates_apply(tmpls, inputs).prompt;
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
void common_chat_templates_free(struct common_chat_templates * tmpls) {
delete tmpls;
}
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
return tmpls->has_explicit_template;
}
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
if (variant != nullptr) {
if (strcmp(variant, "tool_use") == 0) {
if (tmpls->template_tool_use) {
return tmpls->template_tool_use->source().c_str();
}
return nullptr;
} else {
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
}
}
return tmpls->template_default->source().c_str();
}
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override,
const std::string & eos_token_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
GGML_ASSERT(model != nullptr);
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
if (model) {
const auto * vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
}
return std::string();
}
return common_token_to_piece(vocab, token, true);
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
}
common_chat_templates_ptr tmpls(new common_chat_templates());
tmpls->has_explicit_template = has_explicit_template;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
}
if (!template_tool_use_src.empty()) {
try {
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
}
return tmpls;
}
std::string common_chat_format_name(common_chat_format format) {
switch (format) {
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
this->position = position - 1;
this->found_error = true;
return false;
}
bool null() override { return true; }
bool boolean(bool) override { return true; }
bool number_integer(number_integer_t) override { return true; }
bool number_unsigned(number_unsigned_t) override { return true; }
bool number_float(number_float_t, const string_t &) override { return true; }
bool string(string_t &) override { return true; }
bool binary(binary_t &) override { return true; }
bool start_object(std::size_t) override { return true; }
bool key(string_t &) override { return true; }
bool null() override { return true; } // NOLINT
bool boolean(bool) override { return true; } // NOLINT
bool number_integer(number_integer_t) override { return true; } // NOLINT
bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
bool string(string_t &) override { return true; } // NOLINT
bool binary(binary_t &) override { return true; } // NOLINT
bool start_object(std::size_t) override { return true; } // NOLINT
bool key(string_t &) override { return true; } // NOLINT
bool end_object() override { return true; }
bool start_array(std::size_t) override { return true; }
bool start_array(std::size_t) override { return true; } // NOLINT
bool end_array() override { return true; }
};
json_error_locator err_loc;
@@ -187,13 +612,20 @@ static std::string apply(
// tmpl_inputs.now = std::chrono::system_clock::now();
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false;
tmpl_opts.use_eos_token = false;
return tmpl.apply(tmpl_inputs, tmpl_opts);
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
}
return result;
}
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto tool_call_schemas = json::array();
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
{"required", json::array({"tool_call"})},
};
const auto schema =
inputs.tool_choice != "required"
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
? json {
{"anyOf", json::array({
tool_call,
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
return result;
}
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
const auto & parameters_required = parameters.at("required");
for (const auto & prop : expected_properties) {
if (!parameters_properties.contains(prop)) {
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
}
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
}
}
if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
}
}
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha") {
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ match[1],
/* .arguments = */ (json {
{arg_name, arg_value},
}).dump(),
/* .id = */ "",
},
},
};
common_chat_msg msg;
msg.role = "assistant";
msg.content = match.prefix().str();
msg.tool_calls.push_back({
/* .name = */ name,
/* .arguments = */ (json {
{arg_name, arg_value},
}).dump(),
/* .id = */ "",
});
return msg;
}
}
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n"
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
return msg;
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
fprintf(stderr, "%s\n", __func__);
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
}
}
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
auto type = parameters.at("type");
const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
return {
/* .role = */ "assistant",
/* .content = */ match.prefix().str(),
/* .tool_calls = */ {
{
/* .name = */ "python",
/* .arguments = */ (json {{"code", code}}).dump(),
/* .id = */ "",
},
}
};
common_chat_msg msg;
msg.role = "assistant";
msg.content = match.prefix().str();
msg.tool_calls.push_back({
/* .name = */ "python",
/* .arguments = */ (json {{"code", code}}).dump(),
/* .id = */ "",
});
return msg;
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = inputs.tool_choice != "required";
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
common_chat_msg msg;
msg.role = "assistant";
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
msg.content = input;
return msg;
}
common_chat_msg result;
result.role = "assistant";
result.content = rit->prefix();
msg.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call.at("arguments");
result.tool_calls.push_back({
msg.tool_calls.push_back({
call.at("name"),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
break;
}
}
return result;
return msg;
} catch (const std::exception & e) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
common_chat_msg msg;
msg.role = "assistant";
msg.content = input;
return msg;
}
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
static common_chat_params common_chat_templates_apply_jinja(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
templates_params params;
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
? *tmpls->template_tool_use
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
params.add_generation_prompt = inputs.add_generation_prompt;
params.extract_reasoning = inputs.extract_reasoning;
params.tool_choice = inputs.tool_choice;
params.grammar = inputs.grammar;
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
if (inputs.tools.is_array()) {
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
params.parallel_tool_calls = false;
} else {
params.parallel_tool_calls = inputs.parallel_tool_calls;
}
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
@@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
}
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
if (src.find("<tool▁calls▁begin>") != std::string::npos && inputs.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, inputs);
if (src.find("<tool▁calls▁begin>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, params);
}
// Command R7B: : use handler in all cases except json schema (thinking / tools).
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
return common_chat_params_init_command_r7b(tmpl, inputs);
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_command_r7b(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
return common_chat_params_init_generic(tmpl, inputs);
if ((params.tools.is_array() && params.json_schema.is_object())) {
return common_chat_params_init_generic(tmpl, params);
}
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
if (src.find(">>>all") != std::string::npos) {
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
return common_chat_params_init_functionary_v3_2(tmpl, params);
}
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
if (src.find(" functools[") != std::string::npos) {
return common_chat_params_init_firefunction_v2(tmpl, inputs);
return common_chat_params_init_firefunction_v2(tmpl, params);
}
// Plain handler (no tools)
if (inputs.tools.is_null() || inputs.tool_choice == "none") {
return common_chat_params_init_without_tools(tmpl, inputs);
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) {
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
}
// Llama 3.1, 3.2, 3.3 (w/ tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
}
// Mistral Nemo (w/ tools)
if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs);
return common_chat_params_init_mistral_nemo(tmpl, params);
}
// Generic fallback
return common_chat_params_init_generic(tmpl, inputs);
return common_chat_params_init_generic(tmpl, params);
}
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
static common_chat_params common_chat_templates_apply_legacy(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
continue;
}
if (!content.empty()) {
content += "\n";;
}
content += part.text;
}
contents.emplace_back(std::move(content));
}
for (size_t i = 0; i < contents.size(); ++i) {
const auto & msg = inputs.messages[i];
const auto & content = contents[i];
chat.push_back({msg.role.c_str(), content.c_str()});
alloc_size += (msg.role.size() + content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
const auto & src = tmpls->template_default->source();
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
}
common_chat_params params;
params.prompt = std::string(buf.data(), res);
if (!inputs.json_schema.empty()) {
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
} else {
params.grammar = inputs.grammar;
}
return params;
}
common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
GGML_ASSERT(tmpls != nullptr);
return inputs.use_jinja
? common_chat_templates_apply_jinja(tmpls, inputs)
: common_chat_templates_apply_legacy(tmpls, inputs);
}
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
return {
/* .role = */ "assistant",
/* .content = */ input,
/* .tool_calls = */ {},
};
common_chat_msg msg;
msg.role = "assistant";
msg.content = input;
return msg;
}
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
+134
View File
@@ -0,0 +1,134 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <string>
#include <vector>
struct common_chat_templates;
struct common_chat_tool_call {
std::string name;
std::string arguments;
std::string id;
};
struct common_chat_msg_content_part {
std::string type;
std::string text;
};
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_chat_msg_content_part> content_parts = {};
std::vector<common_chat_tool_call> tool_calls = {};
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
};
struct common_chat_tool {
std::string name;
std::string description;
std::string parameters;
};
enum common_chat_tool_choice {
COMMON_CHAT_TOOL_CHOICE_AUTO,
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
COMMON_CHAT_TOOL_CHOICE_NONE,
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_templates_inputs {
std::vector<common_chat_msg> messages;
std::string grammar;
std::string json_schema;
bool add_generation_prompt = true;
bool use_jinja = true;
// Parameters below only supported when use_jinja is true
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
void common_chat_templates_free(struct common_chat_templates * tmpls);
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
struct common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const struct common_chat_templates * tmpls,
bool use_jinja);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
// Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
-55
View File
@@ -1,55 +0,0 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <json.hpp>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
struct common_chat_inputs {
json messages;
json tools;
json tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
json prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
-170
View File
@@ -12,8 +12,6 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat.hpp"
#include "chat-template.hpp"
#include <algorithm>
#include <cinttypes>
@@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text;
}
//
// Chat template utils
//
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
common_chat_params_init(chat_template, inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
common_chat_inputs inputs;
inputs.messages = messages;
inputs.add_generation_prompt = add_ass;
return common_chat_params_init(tmpl, inputs).prompt;
}
int alloc_size = 0;
std::vector<llama_chat_message> chat;
for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant", {}},
{"user", "Hello", {}},
{"assistant", "Hi there", {}},
{"user", "How are you?", {}},
};
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
auto vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
}
return std::string();
} else {
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
try {
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
};
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
return {
has_explicit_template,
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
nullptr,
};
}
}
//
// KV cache utils
//
+2 -58
View File
@@ -178,10 +178,10 @@ struct common_params_speculative {
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.9f; // minimum speculative decoding probability (greedy)
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
@@ -616,62 +616,6 @@ std::string common_detokenize(
const std::vector<llama_token> & tokens,
bool special = true);
//
// Chat template utils
//
struct common_tool_call {
std::string name;
std::string arguments;
std::string id;
};
// same with llama_chat_message, but uses std::string
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
std::string reasoning_content = "";
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
namespace minja {
class chat_template;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const common_chat_template & tmpl, bool use_jinja);
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
//
// KV cache utils
//
+5 -5
View File
@@ -252,11 +252,6 @@ llama_tokens common_speculative_gen_draft(
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_sampler_accept(smpl, id, true);
result.push_back(id);
@@ -265,6 +260,11 @@ llama_tokens common_speculative_gen_draft(
break;
}
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model
+1 -1
View File
@@ -9,7 +9,7 @@ struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
float p_min = 0.9f; // min probability required to accept a token in the draft
float p_min = 0.75f; // min probability required to accept a token in the draft
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+1 -1
View File
@@ -1,6 +1,6 @@
# llama.cpp/examples/imatrix
Compute an importance matrix for a model and given text dataset. Can be used during quantization to enchance the quality of the quantized models.
Compute an importance matrix for a model and given text dataset. Can be used during quantization to enhance the quality of the quantized models.
More information is available here: https://github.com/ggml-org/llama.cpp/pull/4861
## Usage
+16 -11
View File
@@ -4,7 +4,7 @@
#include "log.h"
#include "sampling.h"
#include "llama.h"
#include "chat-template.hpp"
#include "chat.h"
#include <cstdio>
#include <cstring>
@@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
}
const llama_vocab * vocab = llama_model_get_vocab(model);
auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
auto chat_templates = common_chat_templates_init(model, params.chat_template);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
}
// auto enable conversation mode if chat template is available
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
// print chat template example in conversation mode
if (params.conversation_mode) {
if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
} else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
}
@@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content, {}};
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back({role, content, {}});
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back(new_msg);
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
};
@@ -755,11 +757,14 @@ int main(int argc, char ** argv) {
// check for reverse prompt using special tokens
llama_token last_token = common_sampler_last(smpl);
if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
if (params.interactive) {
is_interacting = true;
for (auto token : antiprompt_token) {
if (token == last_token) {
if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
break;
}
is_antiprompt = true;
}
if (is_antiprompt) {
+22 -44
View File
@@ -24,7 +24,7 @@
#include <string>
#include <vector>
#include "chat-template.hpp"
#include "chat.h"
#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
@@ -557,7 +557,7 @@ class LlamaData {
llama_model_ptr model;
llama_sampler_ptr sampler;
llama_context_ptr context;
std::vector<llama_chat_message> messages;
std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs;
std::vector<char> fmtted;
@@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
}
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) {
json messages = json::array();
for (const auto & msg : llama_data.messages) {
messages.push_back({
{"role", msg.role},
{"content", msg.content},
});
}
try {
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages;
tmpl_inputs.add_generation_prompt = append;
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false;
tmpl_opts.use_eos_token = false;
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template(
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
common_chat_templates_inputs inputs;
for (const auto & msg : llama_data.messages) {
common_chat_msg cmsg;
cmsg.role = msg.role;
cmsg.content = msg.content;
inputs.messages.push_back(cmsg);
}
inputs.add_generation_prompt = append;
inputs.use_jinja = use_jinja;
return result;
auto chat_params = common_chat_templates_apply(tmpls, inputs);
// TODO: use other params for tool calls.
auto result = chat_params.prompt;
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
}
// Function to tokenize the prompt
@@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
}
// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
if (new_len < 0) {
printe("failed to apply the chat template\n");
return -1;
@@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) {
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
GGML_ASSERT(chat_templates.template_default);
auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
@@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
return 1;
}
@@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}
+379 -290
View File
File diff suppressed because it is too large Load Diff
Binary file not shown.
+32 -50
View File
@@ -274,7 +274,7 @@ struct server_task {
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_min = std::max(params.speculative.n_min, 0);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
// Use OpenAI API logprobs only if n_probs wasn't provided
@@ -329,9 +329,6 @@ struct server_task {
}
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
@@ -1807,7 +1804,7 @@ struct server_context {
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
common_chat_templates chat_templates;
common_chat_templates_ptr chat_templates;
~server_context() {
// Clear any sampling context
@@ -1891,45 +1888,17 @@ struct server_context {
llama_init_dft.context.reset();
}
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja);
} catch (const std::exception & e) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_from_model(model, "chatml");
} else {
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
chat_templates = common_chat_templates_init(model, "chatml");
}
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
return true;
}
bool validate_builtin_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
GGML_ASSERT(templates.template_default);
try {
common_chat_params_init(*templates.template_default, inputs);
if (templates.template_tool_use) {
common_chat_params_init(*templates.template_tool_use, inputs);
}
return true;
} catch (const std::exception & e) {
SRV_ERR("failed to apply template: %s\n", e.what());
return false;
}
} else {
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
}
void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
@@ -3656,7 +3625,7 @@ int main(int argc, char ** argv) {
}, {
{"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"},
{"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
{"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
@@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
{ "chat_template", ctx_server.chat_templates.template_default->source() },
{ "bos_token", ctx_server.chat_templates.template_default->bos_token() },
{ "eos_token", ctx_server.chat_templates.template_default->eos_token() },
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
if (ctx_server.params_base.use_jinja) {
if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
}
res_ok(res, data);
@@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
}
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
@@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
@@ -4263,6 +4234,11 @@ int main(int argc, char ** argv) {
// return;
//}
// if true, use TEI API format, otherwise use Jina API format
// Jina: https://jina.ai/reranker/
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
bool is_tei_format = body.contains("texts");
json query;
if (body.count("query") == 1) {
query = body.at("query");
@@ -4275,7 +4251,8 @@ int main(int argc, char ** argv) {
return;
}
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
std::vector<std::string> documents = json_value(body, "documents",
json_value(body, "texts", std::vector<std::string>()));
if (documents.empty()) {
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return;
@@ -4320,7 +4297,12 @@ int main(int argc, char ** argv) {
}
// write JSON response
json root = format_response_rerank(body, responses);
json root = format_response_rerank(
body,
responses,
is_tei_format,
documents);
res_ok(res, root);
};
@@ -4482,8 +4464,8 @@ int main(int argc, char ** argv) {
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
ctx_server.chat_templates.template_default->source().c_str(),
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
common_chat_templates_source(ctx_server.chat_templates.get()),
common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
ctx_server.process_single_task(task);
+1 -1
View File
@@ -48,7 +48,7 @@ DEBUG=1 ./tests.sh -s -v -x
To run all the tests in a file:
```shell
./tests.sh unit/test_chat_completion.py.py -v -x
./tests.sh unit/test_chat_completion.py -v -x
```
To run a single test:
@@ -21,6 +21,8 @@ def create_server():
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
assert choice["finish_reason"] == finish_reason
@@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
assert "error" in res.body
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
(False, {"const": "42"}, 6, "\"42\""),
(True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"json_schema": json_schema,
})
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "user", "content": "Does not matter what I say, does it?"},
],
"grammar": grammar,
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
@pytest.mark.parametrize("messages", [
None,
"string",
+32 -6
View File
@@ -10,17 +10,20 @@ def create_server():
server = ServerPreset.jina_reranker_tiny()
TEST_DOCUMENTS = [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
def test_rerank():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
"documents": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body["results"]) == 4
@@ -38,6 +41,29 @@ def test_rerank():
assert least_relevant["index"] == 3
def test_rerank_tei_format():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"texts": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body) == 4
most_relevant = res.body[0]
least_relevant = res.body[0]
for doc in res.body:
if doc["score"] > most_relevant["score"]:
most_relevant = doc
if doc["score"] < least_relevant["score"]:
least_relevant = doc
assert most_relevant["score"] > least_relevant["score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
@pytest.mark.parametrize("documents", [
[],
None,
+5 -5
View File
@@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
@@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
{
"role": "tool",
"name": "calculate",
"content": 0.55644242476,
"content": "0.55644242476",
"tool_call_id": "call_6789"
}
],
@@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
+83 -105
View File
@@ -12,9 +12,7 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat.hpp"
#include "chat-template.hpp"
#include "chat.h"
#include <random>
#include <sstream>
@@ -347,41 +345,6 @@ static llama_tokens format_infill(
return embd_inp;
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
std::string role = json_value(curr_msg, "role", std::string(""));
std::string content;
if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) {
content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) {
content += "\n" + part["text"].get<std::string>();
}
}
} else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
chat.push_back({role, content, /* tool_calls= */ {}});
}
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
}
//
// base64 utils (TODO: move to common in the future)
//
@@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */
bool use_jinja,
common_reasoning_format reasoning_format,
const common_chat_templates & chat_templates)
const struct common_chat_templates * tmpls)
{
json llama_params;
const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
? *chat_templates.template_tool_use
: *chat_templates.template_default;
auto tools = json_value(body, "tools", json());
auto stream = json_value(body, "stream", false);
@@ -610,62 +570,58 @@ static json oaicompat_completion_params_parse(
llama_params["stop"] = json_value(body, "stop", json::array());
}
auto json_schema = json_value(body, "json_schema", json());
auto grammar = json_value(body, "grammar", std::string());
if (!json_schema.is_null() && !grammar.empty()) {
throw std::runtime_error("Cannot use both json_schema and grammar");
}
// Handle "response_format" field
if (body.contains("response_format")) {
json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
json_schema = json_value(response_format, "schema", json::object());
} else if (response_type == "json_schema") {
json json_schema = json_value(response_format, "json_schema", json::object());
llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
json_schema = json_value(json_schema, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
}
// Apply chat template to the list of messages
if (use_jinja) {
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
if (tool_choice != "none" && llama_params.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
common_chat_inputs inputs;
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.messages = body.at("messages");
inputs.tools = tools;
inputs.tool_choice = tool_choice;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
inputs.parallel_tool_calls = false;
}
inputs.stream = stream;
// TODO: support mixing schema w/ tools beyond generic format.
inputs.json_schema = json_value(llama_params, "json_schema", json());
auto chat_params = common_chat_params_init(tmpl, inputs);
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.add_generation_prompt = true;
inputs.use_jinja = use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back({
{"word", trigger.word},
{"at_start", trigger.at_start},
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
} else {
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(tmpls, inputs);
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back({
{"word", trigger.word},
{"at_start", trigger.at_start},
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
// Handle "n" field
@@ -737,29 +693,51 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
return res;
}
static json format_response_rerank(const json & request, const json & ranks) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & rank : ranks) {
data.push_back(json{
{"index", i++},
{"relevance_score", json_value(rank, "score", 0.0)},
});
static json format_response_rerank(
const json & request,
const json & ranks,
bool is_tei_format,
std::vector<std::string> & texts) {
json res;
if (is_tei_format) {
// TEI response format
res = json::array();
bool return_text = json_value(request, "return_text", false);
for (const auto & rank : ranks) {
int index = json_value(rank, "index", 0);
json elem = json{
{"index", index},
{"score", json_value(rank, "score", 0.0)},
};
if (return_text) {
elem["text"] = std::move(texts[index]);
}
res.push_back(elem);
}
} else {
// Jina response format
json results = json::array();
int32_t n_tokens = 0;
for (const auto & rank : ranks) {
results.push_back(json{
{"index", json_value(rank, "index", 0)},
{"relevance_score", json_value(rank, "score", 0.0)},
});
n_tokens += json_value(rank, "tokens_evaluated", 0);
n_tokens += json_value(rank, "tokens_evaluated", 0);
}
res = json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json{
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", results}
};
}
json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json {
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", data}
};
return res;
}
@@ -159,6 +159,35 @@ export default function ChatMessage({
</div>
</details>
)}
{msg.extra && msg.extra.length > 0 && (
<details
className={classNames({
'collapse collapse-arrow mb-4 bg-base-200': true,
'bg-opacity-10': msg.role !== 'assistant',
})}
>
<summary className="collapse-title">
Extra content
</summary>
<div className="collapse-content">
{msg.extra.map(
(extra, i) =>
extra.type === 'textFile' ? (
<div key={extra.name}>
<b>{extra.name}</b>
<pre>{extra.content}</pre>
</div>
) : extra.type === 'context' ? (
<div key={i}>
<pre>{extra.content}</pre>
</div>
) : null // TODO: support other extra types
)}
</div>
</details>
)}
<MarkdownDisplay
content={content}
isGenerating={isPending}
@@ -1,10 +1,11 @@
import { useEffect, useMemo, useState } from 'react';
import { useEffect, useMemo, useRef, useState } from 'react';
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
import ChatMessage from './ChatMessage';
import { CanvasType, Message, PendingMessage } from '../utils/types';
import { classNames, throttle } from '../utils/misc';
import CanvasPyInterpreter from './CanvasPyInterpreter';
import StorageUtils from '../utils/storage';
import { useVSCodeContext } from '../utils/llama-vscode';
/**
* A message display is a message node with additional information for rendering.
@@ -81,6 +82,14 @@ export default function ChatScreen() {
replaceMessageAndGenerate,
} = useAppContext();
const [inputMsg, setInputMsg] = useState('');
const inputRef = useRef<HTMLTextAreaElement>(null);
const { extraContext, clearExtraContext } = useVSCodeContext(
inputRef,
setInputMsg
);
// TODO: improve this when we have "upload file" feature
const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
// keep track of leaf node for rendering
const [currNodeId, setCurrNodeId] = useState<number>(-1);
@@ -115,10 +124,20 @@ export default function ChatScreen() {
setCurrNodeId(-1);
// get the last message node
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) {
if (
!(await sendMessage(
currConvId,
lastMsgNodeId,
inputMsg,
currExtra,
onChunk
))
) {
// restore the input message if failed
setInputMsg(lastInpMsg);
}
// OK
clearExtraContext();
};
const handleEditMessage = async (msg: Message, content: string) => {
@@ -129,6 +148,7 @@ export default function ChatScreen() {
viewingChat.conv.id,
msg.parent,
content,
msg.extra,
onChunk
);
setCurrNodeId(-1);
@@ -143,6 +163,7 @@ export default function ChatScreen() {
viewingChat.conv.id,
msg.parent,
null,
msg.extra,
onChunk
);
setCurrNodeId(-1);
@@ -203,6 +224,7 @@ export default function ChatScreen() {
<textarea
className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)"
ref={inputRef}
value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => {
@@ -25,6 +25,7 @@ interface AppContextValue {
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => Promise<boolean>;
stopGenerating: (convId: string) => void;
@@ -32,6 +33,7 @@ interface AppContextValue {
convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => Promise<void>;
@@ -274,6 +276,7 @@ export const AppContextProvider = ({
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
): Promise<boolean> => {
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
@@ -298,6 +301,7 @@ export const AppContextProvider = ({
convId,
role: 'user',
content,
extra,
parent: leafNodeId,
children: [],
},
@@ -324,6 +328,7 @@ export const AppContextProvider = ({
convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => {
if (isGenerating(convId)) return;
@@ -339,6 +344,7 @@ export const AppContextProvider = ({
convId,
role: 'user',
content,
extra,
parent: parentNodeId,
children: [],
},
@@ -0,0 +1,62 @@
import { useEffect, useState } from 'react';
import { MessageExtraContext } from './types';
// Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
// Ref: https://github.com/ggml-org/llama.cpp/pull/11940
interface SetTextEvData {
text: string;
context: string;
}
/**
* To test it:
* window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
*/
export const useVSCodeContext = (
inputRef: React.RefObject<HTMLTextAreaElement>,
setInputMsg: (text: string) => void
) => {
const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
null
);
// Accept setText message from a parent window and set inputMsg and extraContext
useEffect(() => {
const handleMessage = (event: MessageEvent) => {
if (event.data?.command === 'setText') {
const data: SetTextEvData = event.data;
setInputMsg(data?.text);
if (data?.context && data.context.length > 0) {
setExtraContext({
type: 'context',
content: data.context,
});
}
inputRef.current?.focus();
}
};
window.addEventListener('message', handleMessage);
return () => window.removeEventListener('message', handleMessage);
}, []);
// Add a keydown listener that sends the "escapePressed" message to the parent window
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
window.parent.postMessage({ command: 'escapePressed' }, '*');
}
};
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, []);
return {
extraContext,
// call once the user message is sent, to clear the extra context
clearExtraContext: () => setExtraContext(null),
};
};
+12 -1
View File
@@ -53,12 +53,23 @@ export const copyStr = (textToCopy: string) => {
/**
* filter out redundant fields upon sending to API
* also format extra into text
*/
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
return messages.map((msg) => {
let newContent = '';
for (const extra of msg.extra ?? []) {
if (extra.type === 'context') {
newContent += `${extra.content}\n\n`;
}
}
newContent += msg.content;
return {
role: msg.role,
content: msg.content,
content: newContent,
};
}) as APIMessage[];
}
+14
View File
@@ -42,11 +42,25 @@ export interface Message {
role: 'user' | 'assistant' | 'system';
content: string;
timings?: TimingReport;
extra?: MessageExtra[];
// node based system for branching
parent: Message['id'];
children: Message['id'][];
}
type MessageExtra = MessageExtraTextFile | MessageExtraContext; // TODO: will add more in the future
export interface MessageExtraTextFile {
type: 'textFile';
name: string;
content: string;
}
export interface MessageExtraContext {
type: 'context';
content: string;
}
export type APIMessage = Pick<Message, 'role' | 'content'>;
export interface Conversation {
+15 -6
View File
@@ -41,12 +41,13 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
// GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,6 +200,10 @@ typedef float2 dfloat2;
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
+46
View File
@@ -0,0 +1,46 @@
// Simplified API for asynchronous data loading.
#include "common.cuh"
// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
template <int preload>
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
#ifdef CP_ASYNC_AVAILABLE
#if CUDART_VERSION >= 11040
if (preload == 256) {
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 128) {
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 64) {
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else
#endif // CUDART_VERSION >= 11040
{
asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
: : "r"(dst), "l"(src));
}
#else
GGML_UNUSED(dst);
GGML_UNUSED(src);
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
// Makes each thread wait until its asynchronous data copies are done.
// This does NOT provide any additional synchronization.
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
static __device__ __forceinline__ void cp_async_wait_all() {
#ifdef CP_ASYNC_AVAILABLE
asm volatile("cp.async.wait_all;");
#else
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
+9 -6
View File
@@ -716,7 +716,9 @@ void launch_fattn(
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
@@ -768,13 +770,14 @@ void launch_fattn(
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
const bool short_context = K->ne[1] < 4096;
const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
const int nblocks_stream_k = 2*nsm;
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;
@@ -827,7 +830,7 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());
if constexpr (parallel_blocks == 0) {
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;
+349 -265
View File
@@ -1,7 +1,252 @@
#include "common.cuh"
#include "cp-async.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"
using namespace ggml_cuda_mma;
typedef tile<16, 8, half2> tile_A;
typedef tile< 8, 8, half2> tile_B;
typedef tile<16, 8, float> tile_C_KQ;
typedef tile<16, 4, half2> tile_C_VKQ;
template<int D, int nwarps, int KQ_stride>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
// If cp.async is available, load up to the highest power of 2 in D asynchronously:
#ifdef CP_ASYNC_AVAILABLE
static_assert(D >= 64 && D < 512, "bad D");
constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
constexpr int preload = 64;
constexpr int h2_per_chunk = 16/sizeof(half2);
constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
constexpr int stride_i = WARP_SIZE / chunks_per_row;
#pragma unroll
for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
}
#else
constexpr int k0_sync_start = 0;
#endif // CP_ASYNC_AVAILABLE
static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
// If D is not a power of 2, the rest is loaded synchronously.
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
continue;
}
#pragma unroll
for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
}
}
}
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half * const __restrict__ maskh,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
const float slope,
const float logit_softcap,
const int ne01,
const int ne02,
const int stride_Q,
const int stride_KV,
const int stride_mask,
const int jt,
half2 * const __restrict__ tile_K,
half2 * const __restrict__ tile_V,
const tile_B * const __restrict__ Q_B,
tile_C_VKQ * const __restrict__ VKQ_C,
float2 & KQ_max,
float2 & KQ_rowsum,
const int kb0) {
#ifdef NEW_MMA_AVAILABLE
constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
const int k_VKQ_0 = kb0*KQ_stride;
tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
#ifdef CP_ASYNC_AVAILABLE
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
#else
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
__syncthreads();
#endif // CP_ASYNC_AVAILABLE
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
tile_A K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
}
}
#ifndef CP_ASYNC_AVAILABLE
__syncthreads(); // Only needed if tile_K == tile_V.
#endif // CP_ASYNC_AVAILABLE
if (use_logit_softcap) {
static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
}
if (maskh) {
static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size");
static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
#pragma unroll
for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
const int i = i0 + tile_C_KQ::get_i(l);
const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
float2 KQ_max_new = KQ_max;
static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
}
}
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
}
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < tile_C_KQ::ne; ++l) {
const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
const float diff = KQ_C[k].x[l] - KQ_max_l;
KQ_C[k].x[l] = expf(diff);
if (l % 2 == 0) {
KQ_rowsum_add.x += KQ_C[k].x[l];
} else {
KQ_rowsum_add.y += KQ_C[k].x[l];
}
}
}
{
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
KQ_max = KQ_max_new;
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
}
// Convert KQ C tiles into B tiles for VKQ calculation:
tile_B B[KQ_stride/(np*2*tile_B::J)];
static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
B[k] = get_transposed(get_half2(KQ_C[k]));
}
#ifdef CP_ASYNC_AVAILABLE
cp_async_wait_all();
__syncthreads();
if (!last_iter) {
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
}
#else
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
__syncthreads();
#endif // CP_ASYNC_AVAILABLE
// Calculate VKQ tile:
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
#pragma unroll
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
tile_A A;
load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
}
}
#ifndef CP_ASYNC_AVAILABLE
__syncthreads(); // Only needed if tile_K == tile_V.
#endif // CP_ASYNC_AVAILABLE
#else
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
@@ -13,61 +258,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float scale,
const float slope,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3,
const int stride_Q,
const int stride_KV,
const int stride_mask,
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
typedef mma_C_I16J8<float> mma_C_KQ;
typedef mma_C_I16J8<half2> mma_C_VKQ;
static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
static_assert(D % nwarps == 0, "bad D");
static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
const int stride_Q = nb01 / sizeof(float2);
const int stride_KV = nb11 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half);
// Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
extern __shared__ half2 tile_K[];
#ifdef CP_ASYNC_AVAILABLE
half2 * tile_V = tile_K + KQ_stride*D2_padded;
#else
half2 * tile_V = tile_K;
#endif // CP_ASYNC_AVAILABLE
mma_B Q_B[D/(2*mma_B::K)];
mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
tile_B Q_B[D/(2*tile_B::J)];
tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
float2 KQ_rowsum = {0.0f, 0.0f};
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
float2 KQ_max_scale = {0.0f, 0.0f};
float2 KQ_rowsum = {0.0f, 0.0f};
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
// Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
// Temporarily load Q data into tile_K, will be loaded into registers afterwards.
// The loading is done with decreasing granularity for D for better memory bandwidth.
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
@@ -76,6 +300,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (k0_start == k0_stop) {
continue;
}
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
@@ -90,14 +318,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
}
} else {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
}
}
}
@@ -106,198 +334,42 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
{
const int j0 = (threadIdx.y / np) * mma_B::J;
const int j0 = (threadIdx.y / np) * tile_B::I;
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
}
}
__syncthreads();
// Preload K data for first iteration when using cp_async:
#ifdef CP_ASYNC_AVAILABLE
flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
#endif // CP_ASYNC_AVAILABLE
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
const int k_VKQ_0 = kb0*KQ_stride;
mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
// Load K data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
}
}
}
__syncthreads();
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
mma_A K_A;
K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
}
}
__syncthreads();
if (use_logit_softcap) {
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
}
if (maskh) {
static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
#pragma unroll
for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const int i = i0 + mma_C_KQ::get_i(l);
const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
float2 KQ_max_new = KQ_max;
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
}
}
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
}
{
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.x = 0.0f;
}
if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.y = 0.0f;
}
KQ_max = KQ_max_new;
}
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
const float diff = KQ_C[k].x[l] - KQ_max_l;
KQ_C[k].x[l] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_C[k].x[l] = 0.0f;
}
if (l % 2 == 0) {
KQ_rowsum_add.x += KQ_C[k].x[l];
} else {
KQ_rowsum_add.y += KQ_C[k].x[l];
}
}
}
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < mma_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
// Convert KQ C tiles into B tiles for VKQ calculation:
mma_B B[KQ_stride/(np*2*mma_B::K)];
static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
B[k] = KQ_C[k].to_mma_B();
}
// Load V data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
const int i0_stop = D/2 - (D/2) % (1*stride_i);
const int stride_k = WARP_SIZE / stride_i;
#pragma unroll
for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
#pragma unroll
for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
}
}
}
__syncthreads();
// Calculate VKQ tile:
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
#pragma unroll
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
mma_A A;
A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
}
}
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
// With cp_async there is no __syncthreads at the end of the iter,
// there can be a race condition on shared memory access for combining/writing back results.
#ifdef CP_ASYNC_AVAILABLE
if (nwarps*tile_B::I > KQ_stride) {
__syncthreads();
}
#endif // CP_ASYNC_AVAILABLE
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8 threads each, does not need full reduce.
@@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Write VKQ accumulators to shared memory in column-major format.
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// Also for np > 1 the combination is done via these values in shared memory.
const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int k = k0 + mma_B::get_k(l);
for (int l = 0; l < tile_B::ne; ++l) {
const int k = k0 + tile_B::get_j(l);
tile_KV[j_cwd*D2_padded + k] = B.x[l];
tile_K[j_cwd*D2_padded + k] = B.x[l];
}
}
const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
}
__syncthreads();
@@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
static_assert(np == 1 || np == 2 || np == 4, "bad np");
if (np == 1) {
// No combination is needed, the meta data can be directly written from registers to VRAM.
if (needs_fixup && threadIdx.x < mma_B::J) {
if (needs_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
if (is_fixup && threadIdx.x < mma_B::J) {
if (is_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
@@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.
float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;
float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
KQ_cm = meta_j[0];
}
float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
}
const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
KQ_crs = KQ_cms*meta_j[1];
}
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
}
// Write back combined meta data:
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
meta_j[0] = KQ_cmn; // Combined max. KQ values.
meta_j[1] = KQ_crs; // Combined KQ rowsums.
meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
*((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
}
if (needs_fixup && threadIdx.x < mma_B::J) {
if (needs_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
if (is_fixup && threadIdx.x < mma_B::J) {
if (is_fixup && threadIdx.x < tile_B::I) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
}
@@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (k0_start == k0_stop) {
continue;
}
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
@@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;
if (!is_fixup && jt*ncols + j_dst >= ne01) {
continue;
}
const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
for (int ip = 0; ip < np; ++ip) {
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
dstk_val.x += dstk_val_add.x*KQ_crs;
dstk_val.y += dstk_val_add.y*KQ_crs;
}
@@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
}
#else
NO_DEVICE_CODE;
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
@@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#ifndef NEW_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // NEW_MMA_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16(
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const int stride_Q = nb01 / sizeof(float2);
const int stride_KV = nb11 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half);
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
@@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
}
kbc += iter_k;
@@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
}
template <int D, int cols_per_block>
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
typedef tile<16, 8, half2> tile_A;
typedef tile< 8, 8, half2> tile_B;
static_assert(D % mma_B::K == 0, "bad D");
static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
static_assert(D % tile_B::J == 0, "bad D");
static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");
const ggml_tensor * KQV = dst;
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
constexpr int KQ_stride = D <= 128 ? 64 : 32;
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
constexpr int KQ_stride = D <= 128 ? 64 : 32;
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
const int nrows_combine = nwarps*tile_B::J;
const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+185 -324
View File
@@ -4,11 +4,12 @@
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape I x K.
// B is a column-major matrix with shape K x J.
// C is a column-major matrix with shape I x J.
// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// A is a row-major matrix with shape M x K.
// B is a column-major matrix with shape K x N.
// C is a column-major matrix with shape M x N.
// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
// Note that J is measured in physical 32 bit elements instead of logical elements.
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
#ifdef NEW_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "+r"(ret) : "r"(x));
: "=r"(ret) : "r"(x));
#else
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
@@ -52,407 +53,267 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
#endif // CUDART_VERSION >= 11080
static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
half2 ret;
*((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
return ret;
}
template <typename T>
struct mma_A_I16K4 {
static_assert(sizeof(T) == 4, "bad type size");
namespace ggml_cuda_mma {
static constexpr int I = 16;
static constexpr int K = 4;
static constexpr int ne = 2;
template <int I_, int J_, typename T>
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
T x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && (J == 4 || J == 8)) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 8 && J == 8) {
return 4 * l + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x % 4) + l % 2;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
static __device__ __forceinline__ int get_k(const int /* l */) {
const int ret = threadIdx.x % K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
template <int I_, int J_>
struct tile<I_, J_, half2> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
half2 x[ne] = {{0.0f, 0.0f}};
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
}
return ret;
}
static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
tile<8, 8, half2> ret;
ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
return ret;
}
template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "=r"(xi[0]), "=r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
template <typename T>
struct mma_A_I16K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int I = 16;
static constexpr int K = 8;
static constexpr int ne = 4;
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_k(const int l) {
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(t);
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void transpose() {
int * xi = (int *) x;
xi[0] = ggml_cuda_movmatrix(xi[0]);
const int tmp = ggml_cuda_movmatrix(xi[1]);
xi[1] = ggml_cuda_movmatrix(xi[2]);
xi[2] = tmp;
xi[3] = ggml_cuda_movmatrix(xi[3]);
}
};
template <typename T>
struct mma_B_J8K4 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 4;
static constexpr int ne = 1;
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
static __device__ __forceinline__ int get_k(const int /* l */) {
const int ret = threadIdx.x % K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(xi[0]) : "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
template <typename T>
struct mma_B_J8K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 8;
static constexpr int ne = 2;
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
static __device__ __forceinline__ int get_k(const int l) {
const int ret = l * (K/2) + threadIdx.x % (K/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
template <typename T>
struct mma_C_I16J8 {};
template <>
struct mma_C_I16J8<int> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
int x[ne] = {0};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[1]), "r"(B.x[0]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
#else
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[0]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[1]), "r"(B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
: "+r"(D.x[0]), "+r"(D.x[1])
: "r"(A.x[2]), "r"(B.x[1]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
: "+r"(D.x[2]), "+r"(D.x[3])
: "r"(A.x[3]), "r"(B.x[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
};
template <>
struct mma_C_I16J8<half2> {
static constexpr int I = 16;
static constexpr int J = 4;
static constexpr int ne = 2;
half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = l * (I/2) + threadIdx.x / J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x % J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "+r"(Dxi[0]), "+r"(Dxi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
int * xi = (int *) x;
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
return mma_B;
}
};
template <>
struct mma_C_I16J8<float> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
mma_B.x[0] = make_half2(x[0], x[1]);
mma_B.x[1] = make_half2(x[2], x[3]);
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
return mma_B;
}
__device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_i(l)];
}
}
};
}
+140 -138
View File
@@ -7,6 +7,8 @@
#include <climits>
#include <cstdint>
using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_NWARPS 8
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
typedef mma_A_I16K8<int> mma_A;
typedef mma_B_J8K8<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 8, int> tile_A;
typedef tile< 8, 8, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
mma_A A[ntx][WARP_SIZE/QI8_0];
float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
tile_A A[ntx][WARP_SIZE/QI8_0];
float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
const int k0 = k00 + k01;
A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
mma_B B;
float dB[mma_C::ne/2];
tile_B B;
float dB[tile_C::ne/2];
B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C;
C.mma(A[n][k01/QI8_0], B);
tile_C C;
mma(C, A[n][k01/QI8_0], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
}
}
}
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
typedef mma_A_I16K8<int> mma_A;
typedef mma_B_J8K8<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 8, int> tile_A;
typedef tile< 8, 8, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
mma_A A[ntx][WARP_SIZE/QI8_1];
float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
tile_A A[ntx][WARP_SIZE/QI8_1];
float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
mma_B B;
float2 dsB[mma_C::ne/2];
tile_B B;
float2 dsB[tile_C::ne/2];
B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C;
C.mma(A[n][k01/QI8_1], B);
tile_C C;
mma(C, A[n][k01/QI8_1], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
}
}
}
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
typedef mma_A_I16K4<int> mma_A;
typedef mma_A_I16K8<int> mma_A_K8;
typedef mma_B_J8K4<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
typedef tile< 8, 4, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
mma_A A[ntx][8];
float dA[ntx][mma_C::ne/2][8];
tile_A A[ntx][8];
float dA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
mma_B B[2];
float dB[mma_C::ne/2];
tile_B B[2];
float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C[2];
C[0].mma(A[n][k01/4 + 0], B[0]);
C[1].mma(A[n][k01/4 + 1], B[1]);
tile_C C[2];
mma(C[0], A[n][k01/4 + 0], B[0]);
mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
}
}
}
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
typedef mma_A_I16K4<int> mma_A;
typedef mma_A_I16K8<int> mma_A_K8;
typedef mma_B_J8K4<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 4, int> tile_A;
typedef tile<16, 8, int> tile_A_8;
typedef tile< 8, 4, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
mma_A A[ntx][8];
float dA[ntx][mma_C::ne/2][8];
float mA[ntx][mma_C::ne/2][8];
tile_A A[ntx][8];
float dA[ntx][tile_C::ne/2][8];
float mA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
}
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
float2 dB[mma_C::ne/2];
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
float2 dB[tile_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
mma_B B[2];
tile_B B[2];
// Here load_generic is faster than load_ldmatrix.
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
mma_C Cm[2];
tile_C Cm[2];
if (k01 >= WARP_SIZE * 3/4) {
mma_A A1;
tile_A A1;
A1.x[0] = 0x01010101;
A1.x[1] = 0x01010101;
Cm[0].mma(A1, B[0]);
Cm[1].mma(A1, B[1]);
mma(Cm[0], A1, B[0]);
mma(Cm[1], A1, B[1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C Cd[2];
tile_C Cd[2];
Cd[0].mma(A[n][k01/4 + 0], B[0]);
Cd[1].mma(A[n][k01/4 + 1], B[1]);
mma(Cd[0], A[n][k01/4 + 0], B[0]);
mma(Cd[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
for (int l = 0; l < tile_C::ne; ++l) {
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
if (k01 >= WARP_SIZE * 3/4) {
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
}
sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
}
}
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
float2 sB[mma_C::ne/2];
float2 sB[tile_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
}
}
}
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
typedef mma_A_I16K4<int> mma_A;
typedef mma_B_J8K4<int> mma_B;
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 4, int> tile_A;
typedef tile< 8, 4, int> tile_B;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
mma_A A[ntx][8];
int scA[ntx][mma_C::ne/2][8];
float dA[ntx][mma_C::ne/2];
tile_A A[ntx][8];
int scA[ntx][tile_C::ne/2][8];
float dA[ntx][tile_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
}
#pragma unroll
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int k0 = k00 + k01;
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
const int8_t * sc = (const int8_t *) &sc_packed;
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
}
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
float tmp[ntx][mma_C::ne] = {{0.0f}};
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
float tmp[ntx][tile_C::ne] = {{0.0f}};
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
mma_B B[2];
float dB[mma_C::ne/2];
tile_B B[2];
float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne/2; ++l) {
const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
mma_C C[2];
C[0].mma(A[n][k01/4 + 0], B[0]);
C[1].mma(A[n][k01/4 + 1], B[1]);
tile_C C[2];
mma(C[0], A[n][k01/4 + 0], B[0]);
mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
for (int l = 0; l < tile_C::ne; ++l) {
tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
}
}
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
for (int l = 0; l < tile_C::ne; ++l) {
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
}
}
}
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
typedef mma_C_I16J8<int> mma_C;
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
#ifdef NEW_MMA_AVAILABLE
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#endif // NEW_MMA_AVAILABLE
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
for (int l = 0; l < tile_C::ne; ++l) {
const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
if (j > j_max) {
continue;
}
const int i = i0 + n*mma_C::I + mma_C::get_i(l);
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
if (need_check && i > i_max) {
continue;
}
dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
}
}
}
+1 -1
View File
@@ -24,7 +24,7 @@
#endif
// create residency sets only on macOS >= 15.0
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,51 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
shared uint tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
if (col >= p.KX) {
return;
}
A_TYPE amax = data_a[row*p.KX + col];
tmp[col] = col;
for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
A_TYPE val = data_a[row*p.KX + i];
if (val > amax) {
amax = val;
tmp[col] = i;
}
}
tmpmax[col] = amax;
barrier();
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
if (col < s && col + s < p.KX) {
if (tmpmax[col] < tmpmax[col + s]) {
tmpmax[col] = tmpmax[col + s];
tmp[col] = tmp[col + s];
}
}
barrier();
}
if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
}
}
@@ -0,0 +1,31 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
#include "generic_head.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
const uint CHUNK_SIZE = 512;
void main() {
const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
const uint col = gl_LocalInvocationID.x;
uint count = 0;
[[unroll]]
for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
const uint idx = base + i + col;
if (idx >= p.KX) {
break;
}
count += uint(data_a[idx] == data_b[idx]);
}
atomicAdd(data_d[0], D_TYPE(count));
}
@@ -0,0 +1,42 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) buffer X {A_TYPE x[];};
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
layout (binding = 2) buffer GM {A_TYPE gradm[];};
layout (binding = 3) buffer GV {A_TYPE gradv[];};
layout (binding = 4) readonly buffer P {float params[7];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float alpha = params[0];
const float beta1 = params[1];
const float beta2 = params[2];
const float eps = params[3];
const float wd = params[4];
const float beta1h = params[5];
const float beta2h = params[6];
const float gi = grad[i];
const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
gradm[i] = gmi;
gradv[i] = gvi;
const float mh = gmi*beta1h;
const float vh = sqrt(gvi*beta2h) + eps;
x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
}
@@ -0,0 +1,37 @@
#version 450
#include "types.comp"
#include "generic_unary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
// Destination multi-index (inlined dst_idx)
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
// Accumulate from sources
A_TYPE acc = A_TYPE(0);
for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
}
}
}
}
data_d[get_doffset() + d_idx] = D_TYPE(acc);
}
@@ -25,6 +25,10 @@ layout (push_constant) uniform parameter {
float corr_dims[2];
float theta_scale;
uint has_ff;
uint ne02;
uint s1;
uint s2;
int sections[4];
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
@@ -0,0 +1,60 @@
#version 450
#include "rope_head.comp"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;
if (i0 >= ne0) {
return;
}
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims/2]);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
@@ -3,15 +3,18 @@
#include "rope_head.comp"
void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
if (col >= p.ncols) {
if (i0 >= ne0) {
return;
}
if (col >= p.n_dims) {
const uint i = row*p.ncols + col;
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
@@ -19,19 +22,22 @@ void main() {
return;
}
const uint i = row*p.ncols + col/2;
const uint i2 = row/p.p_delta_rows;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.n_dims/2]);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims/2]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
@@ -3,15 +3,18 @@
#include "rope_head.comp"
void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
if (col >= p.ncols) {
if (i0 >= ne0) {
return;
}
if (col >= p.n_dims) {
const uint i = row*p.ncols + col;
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
@@ -19,19 +22,22 @@ void main() {
return;
}
const uint i = row*p.ncols + col;
const uint i2 = row/p.p_delta_rows;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + 1]);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + 1]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
@@ -0,0 +1,47 @@
#version 450
#include "rope_head.comp"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;
if (i0 >= ne0) {
return;
}
const uint row_dst = gl_GlobalInvocationID.x;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
const int sect_dims = p.sections[0] + p.sections[1];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < p.sections[0]) {
const uint p0 = sector;
theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
}
else if (sector >= p.sections[0] && sector < sec_w) {
const uint p0 = sector - p.sections[0];
theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims]);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
@@ -0,0 +1,29 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#include "types.comp"
#include "generic_binary_head.comp"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
}
@@ -443,6 +443,8 @@ void process_shaders() {
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@@ -452,6 +454,7 @@ void process_shaders() {
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
@@ -491,9 +494,19 @@ void process_shaders() {
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
@@ -505,6 +518,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
for (auto &c : compiles) {
c.wait();
}
+30 -17
View File
@@ -124,9 +124,22 @@ if input_file is None:
connection = sqlite3.connect(input_file)
cursor = connection.cursor()
builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
commit_short_len = len(builds[0][0])
build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
if build_len_min != build_len_max:
logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
"Try purging the the database of old commits.")
cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});")
build_len: int = build_len_min
builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
if not builds:
raise RuntimeError(f"{input_file} does not contain any builds.")
try:
repo = git.Repo(".", search_parent_directories=True)
@@ -140,11 +153,11 @@ def find_parent_in_data(commit: git.Commit):
seen_hexsha8 = set()
while heap:
depth, current_commit = heapq.heappop(heap)
current_hexsha8 = commit.hexsha[:commit_short_len]
if (current_hexsha8,) in builds:
current_hexsha8 = commit.hexsha[:build_len]
if current_hexsha8 in builds:
return current_hexsha8
for parent in commit.parents:
parent_hexsha8 = parent.hexsha[:commit_short_len]
parent_hexsha8 = parent.hexsha[:build_len]
if parent_hexsha8 not in seen_hexsha8:
seen_hexsha8.add(parent_hexsha8)
heapq.heappush(heap, (depth + 1, parent))
@@ -158,40 +171,40 @@ def get_all_parent_hexsha8s(commit: git.Commit):
while unvisited:
current_commit = unvisited.pop(0)
visited.append(current_commit.hexsha[:commit_short_len])
visited.append(current_commit.hexsha[:build_len])
for parent in current_commit.parents:
if parent.hexsha[:commit_short_len] not in visited:
if parent.hexsha[:build_len] not in visited:
unvisited.append(parent)
return visited
def get_commit_name(hexsha8):
def get_commit_name(hexsha8: str):
"""Helper function to find a human-readable name for a commit if possible."""
if repo is None:
return hexsha8
for h in repo.heads:
if h.commit.hexsha[:commit_short_len] == hexsha8:
if h.commit.hexsha[:build_len] == hexsha8:
return h.name
for t in repo.tags:
if t.commit.hexsha[:commit_short_len] == hexsha8:
if t.commit.hexsha[:build_len] == hexsha8:
return t.name
return hexsha8
def get_commit_hexsha8(name):
def get_commit_hexsha8(name: str):
"""Helper function to search for a commit given a human-readable name."""
if repo is None:
return None
for h in repo.heads:
if h.name == name:
return h.commit.hexsha[:commit_short_len]
return h.commit.hexsha[:build_len]
for t in repo.tags:
if t.name == name:
return t.commit.hexsha[:commit_short_len]
return t.commit.hexsha[:build_len]
for c in repo.iter_commits("--all"):
if c.hexsha[:commit_short_len] == name[:commit_short_len]:
return c.hexsha[:commit_short_len]
if c.hexsha[:build_len] == name[:build_len]:
return c.hexsha[:build_len]
return None
@@ -199,7 +212,7 @@ hexsha8_baseline = name_baseline = None
# If the user specified a baseline, try to find a commit for it:
if known_args.baseline is not None:
if (known_args.baseline,) in builds:
if known_args.baseline in builds:
hexsha8_baseline = known_args.baseline
if hexsha8_baseline is None:
hexsha8_baseline = get_commit_hexsha8(known_args.baseline)
@@ -228,7 +241,7 @@ hexsha8_compare = name_compare = None
# If the user has specified a compare value, try to find a corresponding commit:
if known_args.compare is not None:
if (known_args.compare,) in builds:
if known_args.compare in builds:
hexsha8_compare = known_args.compare
if hexsha8_compare is None:
hexsha8_compare = get_commit_hexsha8(known_args.compare)
+1 -1
View File
@@ -21,7 +21,7 @@ def get_chat_template(model_id, variant=None):
# Use huggingface_hub library if available.
# Allows access to gated models if the user has access and ran `huggingface-cli login`.
from huggingface_hub import hf_hub_download
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json"), encoding="utf-8") as f:
config_str = f.read()
except ImportError:
import requests
+178 -178
View File
@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size();
const char * pos = src;
auto handle_repetitions = [&](int min_times, int max_times) {
auto handle_repetitions = [&](int min_times, int max_times) {
if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |
// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
}
uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
break;
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
}
return pos;
uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
break;
}
}
return pos;
}
const char * llama_grammar_parser::parse_rule(const char * src) {
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len);
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}
bool llama_grammar_parser::parse(const char * src) {
try {
+5 -5
View File
@@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case {
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
ggml_tensor * b_argmax = ggml_argmax(ctx, a);
ggml_tensor * b_argmax = ggml_argmax(ctx, b);
ggml_set_name(b_argmax, "b_argmax");
ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
@@ -1511,6 +1511,7 @@ struct test_cont : public test_case {
};
// GGML_OP_ADD
// GGML_OP_SUB
// GGML_OP_MUL
// GGML_OP_DIV
struct test_bin_bcast : public test_case {
@@ -3860,7 +3861,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
test_cases.emplace_back(new test_count_equal());
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
@@ -3885,8 +3887,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
}
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
@@ -3938,7 +3938,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
for (auto op : {ggml_add, ggml_mul, ggml_div}) {
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
}
};
+32 -22
View File
@@ -1,13 +1,14 @@
#include <string>
#include <vector>
#include <sstream>
#include <regex>
#undef NDEBUG
#include <cassert>
#include "llama.h"
#include "common.h"
#include "chat-template.hpp"
#include "chat.h"
static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
@@ -18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) {
#endif
}
static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
common_chat_msg msg;
msg.role = role;
msg.content = content;
return msg;
}
int main(void) {
std::vector<llama_chat_message> conversation {
{"system", "You are a helpful assistant"},
@@ -50,7 +58,7 @@ int main(void) {
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .bos_token= */ "<s>",
/* .eos_token= */ "</s>",
},
{
@@ -72,8 +80,8 @@ int main(void) {
{
/* .name= */ "mlabonne/AlphaMonarch-7B",
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "<s>",
/* .eos_token= */ "</s>",
},
@@ -87,7 +95,7 @@ int main(void) {
/* .name= */ "OrionStarAI/Orion-14B-Chat",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: ",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
@@ -304,12 +312,9 @@ int main(void) {
}
}
json messages = json::array();
std::vector<common_chat_msg> messages;
for (const auto & msg : conversation) {
messages.push_back({
{"role", msg.role},
{"content", msg.content},
});
messages.push_back(simple_msg(msg.role, msg.content));
}
for (const auto & test_case : test_cases) {
if (!test_case.supported_with_jinja) {
@@ -317,8 +322,13 @@ int main(void) {
}
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
try {
minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
common_chat_templates_inputs inputs;
inputs.use_jinja = true;
inputs.messages = messages;
inputs.add_generation_prompt = add_generation_prompt;
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
output = normalize_newlines(output);
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
if (output != expected_output) {
printf("Expected:\n%s\n", expected_output.c_str());
@@ -336,11 +346,11 @@ int main(void) {
// test llama_chat_format_single for system message
printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
std::vector<common_chat_msg> chat2;
common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
auto sys_msg = simple_msg("system", "You are a helpful assistant");
auto fmt_sys = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", "");
auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
return output;
@@ -360,14 +370,14 @@ int main(void) {
// test llama_chat_format_single for user message
printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
chat2.push_back({"system", "You are a helpful assistant", {}});
chat2.push_back({"user", "Hello", {}});
chat2.push_back({"assistant", "I am assistant", {}});
common_chat_msg new_msg{"user", "How are you", {}};
chat2.push_back(simple_msg("system", "You are a helpful assistant"));
chat2.push_back(simple_msg("user", "Hello"));
chat2.push_back(simple_msg("assistant", "I am assistant"));
auto new_msg = simple_msg("user", "How are you");
auto fmt_single = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", "");
auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
auto fmt_single = [&](const std::string & tmpl_str) {
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
return output;
+446 -387
View File
File diff suppressed because it is too large Load Diff