Compare commits

...

15 Commits

Author SHA1 Message Date
Łukasz Ślusarczyk 9c404ed54c sycl: use oneDNN for matrices multiplication (#12972) 2025-05-15 16:53:41 +02:00
Diego Devesa 6c8b91500e llama-bench : fix -ot with dl backends (#13563) 2025-05-15 15:46:55 +02:00
Xuan-Son Nguyen 3cc1f1f1d2 webui : handle PDF input (as text or image) + convert pasted long content to file (#13562)
* webui : handle PDF input (as text or image)

* handle the case where pdf image + server without mtmd

* fix bug missing pages
2025-05-15 14:24:50 +02:00
Piotr Wilkin (ilintar) c753d7bed0 server : proper error handling for missing elements in messages array (OpenAI compatible backend) (#13540) 2025-05-15 08:40:58 +02:00
Georgi Gerganov b2838049cc bench : handle decode errors (#13548)
ggml-ci
2025-05-15 05:57:02 +03:00
Olivier Chafik aa48e373f2 server: inject date_string in llama 3.x template + fix date for firefunction v2 (#12802)
* Inject date_string in llama 3.x + fix for functionary v2

https://github.com/ggml-org/llama.cpp/issues/12729

* move/fix detection of functionary v3.1 before llama 3.x, fix & test their non-tool mode

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* generate more tokens in test_completion_with_required_tool_tiny_fast to avoid truncation

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-05-15 02:39:51 +01:00
Georgi Gerganov e3a9421b78 kv-cache : fix out-of-bounds view during reserve graph (#13547)
* kv-cache : fix reserve graph out-of-bounds access

ggml-ci

* cont : add comment

* cont : fix comments [no ci]

* cont : more correct comment [no ci]
2025-05-14 23:15:15 +03:00
Yibo Cai 5ab5d5fb25 arm64: optimize q6_k_q8_k kernel with i8mm (#13519)
This PR improves q6_k_q8_k gemm kernel with arm64 i8mm instruction.

Tested on neoverse-n2 with llama3 8b q6_k quantization model.
- 40% ~ 54% S_PP uplift for all batch sizes
- 16% ~ 47% S_TG uplift for batch size 4 and above

Perplexity doesn't change with this PR.

```
// tested on neoverse-n2
$ llama-batched-bench \
      -m Meta-Llama-3-8B-Instruct-Q6_K.gguf \
      --no-mmap -fa \
      -c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \
      -npl 1,2,4,8,16,32 \
      -t 64

---------------------------------------------------------------------
|    PP |     TG |    B |       S_PP t/s      |       S_TG t/s      |
|       |        |      | original |  this pr | original |  this pr |
|-------|--------|------|----------|----------|----------|----------|
|   128 |    128 |    1 |    78.52 |   109.18 |    18.63 |    18.88 |
|   128 |    128 |    2 |    84.62 |   123.94 |    34.54 |    36.92 |
|   128 |    128 |    4 |    84.36 |   122.49 |    52.65 |    61.32 |
|   128 |    128 |    8 |    90.52 |   138.87 |    63.46 |    84.41 |
|   128 |    128 |   16 |    90.11 |   138.56 |    71.04 |   101.33 |
|   128 |    128 |   32 |    89.81 |   137.79 |    75.14 |   110.47 |
---------------------------------------------------------------------
```
2025-05-14 21:53:52 +02:00
Olivier Chafik 3198405e98 common: add partial regex support (#12808)
* move string_find_partial_stop & string_ends_with to common

* add common_regex (supports partial matches)

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/regex-partial.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/regex-partial.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update common/regex-partial.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* partial regex: add missing iterator end checks

* string utils: use string_views

* direct throw to avoid ggml.h include

* regex-partial: replace missed ggml_asserts

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-05-14 19:50:57 +01:00
Sigbjørn Skjæret f5170c1d7a editorconfig : fix trailing whitespace from #13542 (#13546) 2025-05-14 21:22:49 +03:00
Gilad S. 017f10b5fa fix: crash when calling llama_state_get_size on a context without a KV cache (#13542) 2025-05-14 19:18:18 +03:00
Johannes Gäßler 4696d56749 CUDA: fix crash on large batch size for quant. MoE (#13537) 2025-05-14 16:41:02 +02:00
Diego Devesa b7d2672082 llama : fix quantize with dl backends (#13539) 2025-05-14 16:12:36 +02:00
Johannes Gäßler 6da34fa276 CUDA: faster Deepseek FA, add Turing support (#13435) 2025-05-14 16:08:20 +02:00
Gabe Goodhart 5e7d95e22e fix: Move build_inp_pos to the top of the graph section for build_granite (#13538)
This matches how others do it, but will still avoid the extra
initialization when rope is disabled.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-05-14 15:53:59 +03:00
41 changed files with 1950 additions and 348 deletions
+2
View File
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
minja/minja.hpp
ngram-cache.cpp
ngram-cache.h
regex-partial.cpp
regex-partial.h
sampling.cpp
sampling.h
speculative.cpp
+130 -110
View File
@@ -6,6 +6,15 @@
#include <optional>
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
auto time = std::chrono::system_clock::to_time_t(now);
auto local_time = *std::localtime(&time);
std::ostringstream ss;
ss << std::put_time(&local_time, format.c_str());
auto res = ss.str();
return res;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
@@ -24,6 +33,7 @@ struct templates_params {
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
}
}
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
if (!inputs.tools.is_null()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
expect_tool_parameters(name, parameters, {"code"});
} else {
return false;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
expect_tool_parameters(name, parameters, {"code"});
} else {
return false;
}
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
builtin_tools.push_back(name);
return true;
};
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
if (allow_python_tag_builtin_tools) {
handle_builtin_tool(name, parameters);
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"{\" space "
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
"\"}\" space"));
});
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
});
if (!builtin_tools.empty()) {
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
builtin_tools.push_back(name);
return true;
};
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
if (allow_python_tag_builtin_tools) {
handle_builtin_tool(name, parameters);
}
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"{\" space "
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
"\"}\" space"));
// Allow a few empty lines on top of the usual constrained json schema space rule.
builder.add_rule("root", string_join(tool_rules, " | "));
data.additional_stops.push_back("<|eom_id|>");
});
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
});
if (!builtin_tools.empty()) {
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
// Allow a few empty lines on top of the usual constrained json schema space rule.
builder.add_rule("root", string_join(tool_rules, " | "));
});
data.additional_stops.push_back("<|eom_id|>");
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
: COMMON_CHAT_FORMAT_LLAMA_3_X;
} else {
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
}
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
{"date_string", format_time(inputs.now, "%d %b %Y")},
{"tools_in_user_message", false},
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
});
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
: COMMON_CHAT_FORMAT_LLAMA_3_X;
return data;
}
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
LOG_DBG("%s\n", __func__);
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
const auto & parameters = function.at("parameters");
std::string name = function.at("name");
if (name == "python" || name == "ipython") {
if (!parameters.contains("type")) {
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
if (it.value().at("type") == "string") {
if (!python_code_argument_name.empty()) {
throw std::runtime_error("Multiple string arguments found in python tool");
if (!inputs.tools.is_null()) {
std::string python_code_argument_name;
auto has_raw_python = false;
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
const auto & parameters = function.at("parameters");
std::string name = function.at("name");
if (name == "python" || name == "ipython") {
if (!parameters.contains("type")) {
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
if (it.value().at("type") == "string") {
if (!python_code_argument_name.empty()) {
throw std::runtime_error("Multiple string arguments found in python tool");
}
python_code_argument_name = it.key();
}
python_code_argument_name = it.key();
}
if (python_code_argument_name.empty()) {
throw std::runtime_error("No string argument found in python tool");
}
} else if (type != "string") {
throw std::runtime_error("Invalid type in python tool: " + type.dump());
}
if (python_code_argument_name.empty()) {
throw std::runtime_error("No string argument found in python tool");
}
} else if (type != "string") {
throw std::runtime_error("Invalid type in python tool: " + type.dump());
}
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
});
if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
});
if (has_raw_python) {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
data.preserved_tokens.push_back("<|python_tag|>");
}
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
});
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
} else {
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
}
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
// TODO: if (has_raw_python)
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
return data;
}
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
params.extract_reasoning = inputs.extract_reasoning;
params.tool_choice = inputs.tool_choice;
params.grammar = inputs.grammar;
params.now = inputs.now;
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_firefunction_v2(tmpl, params);
}
// Plain handler (no tools)
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
}
// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
}
// Llama 3.1, 3.2, 3.3 (w/ tools)
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
}
// Plain handler (no tools)
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
}
// Mistral Nemo (w/ tools)
+2
View File
@@ -3,6 +3,7 @@
#pragma once
#include "common.h"
#include <chrono>
#include <string>
#include <vector>
@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
};
struct common_chat_params {
+19
View File
@@ -443,6 +443,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
if (!str.empty() && !stop.empty()) {
const char text_last_char = str.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const auto current_partial = stop.substr(0, char_index + 1);
if (string_ends_with(str, current_partial)) {
return str.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$0");
+4 -4
View File
@@ -6,6 +6,7 @@
#include <set>
#include <string>
#include <string_view>
#include <vector>
#include <sstream>
@@ -503,10 +504,9 @@ static bool string_starts_with(const std::string & str,
return str.rfind(prefix, 0) == 0;
}
static bool string_ends_with(const std::string & str,
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}
// While we wait for C++20's std::string::ends_with...
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);
+204
View File
@@ -0,0 +1,204 @@
#include "regex-partial.h"
#include "common.h"
#include <functional>
#include <optional>
common_regex::common_regex(const std::string & pattern) :
pattern(pattern),
rx(pattern),
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
std::smatch match;
if (pos > input.size()) {
throw std::runtime_error("Position out of bounds");
}
auto start = input.begin() + pos;
auto found = as_match
? std::regex_match(start, input.end(), match, rx)
: std::regex_search(start, input.end(), match, rx);
if (found) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
for (size_t i = 0; i < match.size(); ++i) {
auto begin = pos + match.position(i);
res.groups.emplace_back(begin, begin + match.length(i));
}
return res;
}
std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
if ((!as_match) || it == input.begin()) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
const size_t begin = std::distance(input.begin(), it);
const size_t end = input.size();
if (begin == std::string::npos || end == std::string::npos || begin > end) {
throw std::runtime_error("Invalid range");
}
res.groups.push_back({begin, end});
return res;
}
}
}
return {};
}
/*
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
- /a|b/ -> (a|b).*
- /a*?/ -> error, could match ""
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
- /.*?ab/ -> ((?:b)?a).* (merge .*)
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
const auto end = pattern.end();
std::function<std::string()> process = [&]() {
std::vector<std::vector<std::string>> alternatives(1);
std::vector<std::string> * sequence = &alternatives.back();
while (it != end) {
if (*it == '[') {
auto start = it;
++it;
while (it != end) {
if ((*it == '\\') && (++it != end)) {
++it;
} else if ((it != end) && (*it == ']')) {
break;
} else {
++it;
}
}
if (it == end) {
throw std::runtime_error("Unmatched '[' in pattern");
}
++it;
sequence->push_back(std::string(start, it));
} else if (*it == '*' || *it == '?' || *it == '+') {
if (sequence->empty()) {
throw std::runtime_error("Quantifier without preceding element");
}
sequence->back() += *it;
auto is_star = *it == '*';
++it;
if (is_star) {
if (*it == '?') {
++it;
}
}
} else if (*it == '{') {
if (sequence->empty()) {
throw std::runtime_error("Repetition without preceding element");
}
++it;
auto start = it;
while (it != end && *it != '}') {
++it;
}
if (it == end) {
throw std::runtime_error("Unmatched '{' in pattern");
}
auto parts = string_split(std::string(start, it), ",");
++it;
if (parts.size() > 2) {
throw std::runtime_error("Invalid repetition range in pattern");
}
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
if (s.empty()) {
return def;
}
return std::stoi(s);
};
auto min = parseOptInt(parts[0], 0);
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
if (min && max && *max < *min) {
throw std::runtime_error("Invalid repetition range in pattern");
}
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
auto part = sequence->back();
sequence->pop_back();
for (int i = 0; i < *min; i++) {
sequence->push_back(part);
}
if (max) {
for (int i = *min; i < *max; i++) {
sequence->push_back(part + "?");
}
} else {
sequence->push_back(part + "*");
}
} else if (*it == '(') {
++it;
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
it += 2;
}
auto sub = process();
if (*it != ')') {
throw std::runtime_error("Unmatched '(' in pattern");
}
++it;
auto & part = sequence->emplace_back("(?:");
part += sub;
part += ")";
} else if (*it == ')') {
break;
} else if (*it == '|') {
++it;
alternatives.emplace_back();
sequence = &alternatives.back();
} else if (*it == '\\' && (++it != end)) {
auto str = std::string("\\") + *it;
sequence->push_back(str);
++it;
} else if (it != end) {
sequence->push_back(std::string(1, *it));
++it;
}
}
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function.
std::vector<std::string> res_alts;
for (const auto & parts : alternatives) {
auto & res = res_alts.emplace_back();
for (size_t i = 0; i < parts.size() - 1; i++) {
res += "(?:";
}
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
res += *it;
if (it != parts.rend() - 1) {
res += ")?";
}
}
}
return string_join(res_alts, "|");
};
auto res = process();
if (it != end) {
throw std::runtime_error("Unmatched '(' in pattern");
}
return "(" + res + ")[\\s\\S]*";
}
+56
View File
@@ -0,0 +1,56 @@
#pragma once
#include <regex>
#include <string>
enum common_regex_match_type {
COMMON_REGEX_MATCH_TYPE_NONE,
COMMON_REGEX_MATCH_TYPE_PARTIAL,
COMMON_REGEX_MATCH_TYPE_FULL,
};
struct common_string_range {
size_t begin;
size_t end;
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
if (begin > end) {
throw std::runtime_error("Invalid range");
}
}
// prevent default ctor
common_string_range() = delete;
bool empty() const {
return begin == end;
}
bool operator==(const common_string_range & other) const {
return begin == other.begin && end == other.end;
}
};
struct common_regex_match {
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
std::vector<common_string_range> groups;
bool operator==(const common_regex_match & other) const {
return type == other.type && groups == other.groups;
}
bool operator!=(const common_regex_match & other) const {
return !(*this == other);
}
};
class common_regex {
std::string pattern;
std::regex rx;
std::regex rx_reversed_partial;
public:
explicit common_regex(const std::string & pattern);
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
const std::string & str() const { return pattern; }
};
// For testing only (pretty print of failures).
std::string regex_to_reversed_partial_regex(const std::string & pattern);
+2
View File
@@ -731,6 +731,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
@@ -741,6 +742,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
+1
View File
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
option(GGML_SYCL "ggml: use SYCL" OFF)
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
+195
View File
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
#ifdef __ARM_FEATURE_MATMUL_INT8
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q6_K * GGML_RESTRICT x0 = x;
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
const block_q8_K * GGML_RESTRICT y0 = y;
const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
float32x4_t vfsum = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
const uint8_t * GGML_RESTRICT ql0 = x0->ql;
const uint8_t * GGML_RESTRICT ql1 = x1->ql;
const uint8_t * GGML_RESTRICT qh0 = x0->qh;
const uint8_t * GGML_RESTRICT qh1 = x1->qh;
const int8_t * GGML_RESTRICT qy0 = y0->qs;
const int8_t * GGML_RESTRICT qy1 = y1->qs;
const uint8x16_t mone = vdupq_n_u8(0x30);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
int32x4_t visum = vdupq_n_s32(0);
// process 8 blocks per iteration, totally 16 blocks
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
int8x16_t vx0[8], vx1[8];
// de-quantize vx0[8]
{
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
}
// de-quantize vx1[8]
{
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
}
// process 16 elements (one block with same scale) per iteration
// - vx = concat(ql, qh) - 32
// - r1,r2,r3,r4 = smmla(vx, vy)
for (int k = 0; k < 8; ++k) {
const int blk = j * 8 + k;
const int8x16_t vy0 = vld1q_s8(qy0);
const int8x16_t vy1 = vld1q_s8(qy1);
qy0 += 16;
qy1 += 16;
const int32x4_t block_scale = {
x0->scales[blk],
x0->scales[blk],
x1->scales[blk],
x1->scales[blk],
};
// calculate four results at once with outer product
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
int32x4_t vr = vdupq_n_s32(0);
vr = vmmlaq_s32(vr, vx_l, vy_l);
vr = vmmlaq_s32(vr, vx_h, vy_h);
// apply block scale, will NOT overflow
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
visum = vmlaq_s32(visum, vr, block_scale);
}
}
// adjust bias, apply superblock scale
{
int32_t bias[4];
#ifdef __ARM_FEATURE_SVE
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
const svint64_t zero = svdup_n_s64(0);
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
#else
// NEON doesn't support int16 dot product, fallback to separated mul and add
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
int8x16_t scales_s8 = vld1q_s8(x0->scales);
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
scales_s8 = vld1q_s8(x1->scales);
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
int32x4_t prod;
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
bias[0] = vaddvq_s32(prod);
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
bias[1] = vaddvq_s32(prod);
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
bias[2] = vaddvq_s32(prod);
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
bias[3] = vaddvq_s32(prod);
#endif
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
const float32x4_t superblock_scale = {
GGML_FP16_TO_FP32(x0->d) * y0->d,
GGML_FP16_TO_FP32(x0->d) * y1->d,
GGML_FP16_TO_FP32(x1->d) * y0->d,
GGML_FP16_TO_FP32(x1->d) * y1->d,
};
visum = vsubq_s32(visum, vibias);
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
}
}
// vfsum = ABCD -> ACBD
// AC -> s, BD -> (s+bs)
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
vst1_f32(s, vget_low_f32 (vfsum));
vst1_f32(s + bs, vget_high_f32(vfsum));
return;
}
#endif
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
+4
View File
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_K,
.vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
.nrows = 1,
#endif
},
[GGML_TYPE_IQ2_XXS] = {
.from_float = NULL,
+13 -5
View File
@@ -678,10 +678,14 @@ void launch_fattn(
) {
constexpr int ncols = ncols1 * ncols2;
const bool is_mla = DV == 512; // TODO better parameterization
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
GGML_ASSERT(V || is_mla);
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
@@ -689,6 +693,10 @@ void launch_fattn(
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
size_t nb12 = K->nb[2];
size_t nb13 = K->nb[3];
const char * V_data = (const char *) V->data;
size_t nb21 = V->nb[1];
size_t nb22 = V->nb[2];
size_t nb23 = V->nb[3];
const char * V_data = V ? (const char *) V->data : nullptr;
size_t nb21 = V ? V->nb[1] : nb11;
size_t nb22 = V ? V->nb[2] : nb12;
size_t nb23 = V ? V->nb[3] : nb13;
if (need_f16_K && K->type != GGML_TYPE_F16) {
GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
nb13 = nb13*bs*sizeof(half)/ts;
}
if (need_f16_V && V->type != GGML_TYPE_F16) {
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
GGML_ASSERT(ggml_is_contiguously_allocated(V));
V_f16.alloc(ggml_nelements(V));
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+261 -64
View File
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 32;
static constexpr int nbatch_V2 = 32;
static constexpr int nbatch_combine = 32;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 32;
}
};
template <>
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 40;
static constexpr int nbatch_V2 = 40;
static constexpr int nbatch_combine = 40;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 40;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 40;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 40;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 40;
}
};
template <>
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 48;
static constexpr int nbatch_V2 = 48;
static constexpr int nbatch_combine = 48;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 48;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 48;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 48;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 48;
}
};
template <>
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 56;
static constexpr int nbatch_V2 = 56;
static constexpr int nbatch_combine = 56;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 56;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 56;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 56;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 56;
}
};
template <>
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 64;
static constexpr int nbatch_V2 = 64;
static constexpr int nbatch_combine = 64;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 64;
}
};
template <>
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static constexpr int nbatch_K2 = 128;
static constexpr int nbatch_V2 = 128;
static constexpr int nbatch_combine = 128;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_combine_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 128 : 64;
}
return 64;
}
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 128 : 64;
#else
GGML_UNUSED(ncols);
return 128;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
};
template <>
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
static constexpr int nwarps_max = 8;
static constexpr bool Q_in_reg = false;
static constexpr int nstages_target = 1;
static constexpr int nbatch_K2 = 160;
static constexpr int nbatch_V2 = 128;
static constexpr int nbatch_combine = 128;
static int get_nbatch_K2_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 96 : 160;
}
return ncols <= 16 ? 288 : 160;
}
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 96 : 160;
#else
return ncols <= 16 ? 288 : 160;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
static int get_nbatch_V2_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
return ncols <= 16 ? 64 : 128;
}
return ncols <= 16 ? 256 : 128;
}
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
return ncols <= 16 ? 64 : 128;
#else
return ncols <= 16 ? 256 : 128;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 128;
}
};
// ------------------------------------------------------------------------------------------------------------------
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
auto load = [&] __device__ (const int n) {
auto load = [&] __device__ (auto n) {
const int stride_k = WARP_SIZE >> n;
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
}
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int ncols = ncols1 * ncols2;
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = c::nbatch_K2 + 4;
constexpr int stride_tile_V = c::nbatch_V2 + 4;
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = nbatch_K2 + 4;
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
const int k_VKQ_0 = kb0 * c::nbatch_fa;
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
if constexpr (nstages > 1) {
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
static_assert(!mla, "multi-stage loading not implemented for MLA");
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
constexpr bool use_cp_async = true;
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
#pragma unroll
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start;
if (nstages <= 1) {
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
}
}
#pragma unroll
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
const int i0_diff = i0_stop - i0_start;
if (nstages <= 1) {
// For MLA K and V have the same data.
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
#pragma unroll
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
const int i0_diff = i0_stop - i0_start;
if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
__syncthreads();
}
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
// Calculate VKQ tile:
#pragma unroll
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
tile_A A;
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
if (ntiles == 1) {
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
} else {
@@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#endif // NEW_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = c::nbatch_K2 + 4;
constexpr int stride_tile_V = c::nbatch_V2 + 4;
constexpr int stride_tile_Q = DKQ/2 + 4;
constexpr int stride_tile_K = nbatch_K2 + 4;
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
extern __shared__ half2 tile_Q[];
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// Preload mask and K data for first iteration when using cp_async with multiple stages:
if constexpr (nstages > 1) {
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
constexpr bool use_cp_async = true;
if (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
}
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
constexpr int tile_stride = nbatch_combine + 4;
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
@@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#endif // NEW_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
__launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
@@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
return;
}
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q1 = nb01 / sizeof(float2);
const int stride_Q2 = nb02 / sizeof(float2);
const int stride_K = nb11 / sizeof(half2);
const int stride_V = nb21 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half2);
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
@@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
@@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
typedef fattn_mma_f16_config<DKQ, DV> c;
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
constexpr int ncols = ncols1 * ncols2;
@@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
constexpr bool mla = DKQ == 576;
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
static_assert(DV % tile_A::J == 0, "bad DV");
static_assert(ncols % cols_per_warp == 0, "bad ncols");
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
@@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+2 -1
View File
@@ -10,6 +10,7 @@
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * Q = dst->src[0];
if constexpr (ncols2 <= 8) {
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
return;
}
if (Q->ne[1] <= 32/ncols2) {
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
return;
}
+1 -1
View File
@@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
if (!new_mma_available(cc)) {
return false;
}
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
+2
View File
@@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
@@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+7 -6
View File
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
if (i0 >= ne0) {
return;
}
const int64_t i1 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
// Load 4 floats per thread and calculate max. abs. value between them:
@@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
switch (mmq_get_q8_1_ds_layout(type_src0)) {
case MMQ_Q8_1_DS_LAYOUT_D4:
+26 -22
View File
@@ -49,34 +49,38 @@ endif()
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
# Link against oneDNN
find_package(DNNL)
set(GGML_SYCL_DNNL 0)
if(DNNL_FOUND)
if (NOT DEFINED DNNL_GPU_VENDOR)
# default to intel target
set(DNNL_GPU_VENDOR "INTEL")
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
if(GGML_SYCL_DNN)
find_package(DNNL)
if(DNNL_FOUND)
if (NOT DEFINED DNNL_GPU_VENDOR)
# default to intel target
set(DNNL_GPU_VENDOR "INTEL")
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
endif()
endif()
endif()
# Verify oneDNN was compiled for the same target as llama
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
set(GGML_SYCL_DNNL 1)
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
foreach(CONFIG ${CONFIGS})
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
message(STATUS "Found oneDNN: ${DNNL_LIB}")
endforeach()
# Verify oneDNN was compiled for the same target as llama
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
set(GGML_SYCL_DNNL 1)
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
foreach(CONFIG ${CONFIGS})
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
message(STATUS "Found oneDNN: ${DNNL_LIB}")
endforeach()
else()
message(WARNING
"oneDNN must be compiled for the same target as llama.cpp.
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
Disabling oneDNN support.")
endif()
else()
message(WARNING
"oneDNN must be compiled for the same target as llama.cpp.
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
Disabling oneDNN support.")
message(STATUS "oneDNN not found, disabling oneDNN support")
endif()
else()
message(STATUS "oneDNN not found, disabling oneDNN support")
message(STATUS "oneDNN support disabled by the user")
endif()
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
+37 -8
View File
@@ -32,16 +32,36 @@ public:
else static_assert(0);
}
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
// matrix A has m rows, k columns
// matrix B has k rows, n columns
// nra - number of elements to skip when moving into next row in A
// nrb - number of elements to skip when moving into next row in B
// nca - number of elements to skip when moving into next column in A
// ncb - number of elements to skip when moving into next column in B
// stride_a - number of elements to skip when moving to next A matrix
// stride_b - number of elements to skip when moving to next B matrix
// batches_a - number of A matrices
// batches_b - number of B matrices
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
auto stream = ctx.stream_dnnl(q);
auto eng = ctx.engine_dnnl(q);
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
// { # strides, # rows, # columns }
dnnl::memory::dims a_dims = { batches_a, m, k };
dnnl::memory::dims b_dims = { batches_b, k, n };
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
// { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
dnnl::memory::dims a_strides = { stride_a, nra, nca };
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ public:
matmul_prim.execute(stream, matmul_args);
}
// matrices A and B are column major, both having k rows
// matrix A has m column, matrix B has n columns
// output: column major matrix C = A transposed * B
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
}
};
#endif
+125 -65
View File
@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
int g_ggml_sycl_debug = 0;
int g_ggml_sycl_disable_optimize = 0;
int g_ggml_sycl_disable_graph = 0;
int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_prioritize_dmmv = 0;
static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
#ifdef GGML_SYCL_GRAPH
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
#endif
#if GGML_SYCL_DNNL
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
#endif
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
GGML_LOG_INFO("Build with Macros:\n");
#if defined(GGML_SYCL_FORCE_MMQ)
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne00 == ne10);
const int64_t row_diff = row_high - row_low;
int id;
SYCL_CHECK(
CHECK_TRY_ERROR(id = get_current_device_id()));
#if !GGML_SYCL_DNNL
const int64_t ne0 = dst->ne[0];
const int64_t ne0 = dst->ne[0]; // used by MKL only
// the main device has a larger memory buffer to hold the results from all GPUs
// ldc == nrows of the matrix that cuBLAS writes into
int ldc = id == ctx.device ? ne0 : row_diff;
#endif
int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
#ifdef GGML_SYCL_F16
bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
: src1_as_f16.get();
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
#if !GGML_SYCL_DNNL
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
dst_f16.get(), dpct::library_data_t::real_half, ldc,
dpct::library_data_t::real_half)));
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
}
else
#endif
{
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
dst_f16.get(), dpct::library_data_t::real_half, ldc,
dpct::library_data_t::real_half)));
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
}
else {
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
#if !GGML_SYCL_DNNL
const float alpha = 1.0f;
const float beta = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
#endif
{
const float alpha = 1.0f;
const float beta = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
}
}
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
}
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
char * dst_t = reinterpret_cast<char *>(dst_ddf);
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
@@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
GGML_ASSERT(ne10 == ne00);
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12 * ne13;
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
};
sycl::range<3> block_dims(1, ne12, ne13);
queue->submit([&](sycl::handler & cgh) {
const void ** ptrs_src_get = ptrs_src.get();
void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
if (r2 == 1 && r3 == 1) {
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
}
else {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
}
}
} else {
// iterate over batches from smaller set of matrices (matrix 0)
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
}
}
}
}
else
#endif
{
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
} else {
const int ne23 = ne12 * ne13;
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
sycl::range<3> block_dims(1, ne12, ne13);
queue->submit([&](sycl::handler & cgh) {
const void ** ptrs_src_get = ptrs_src.get();
void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
});
});
});
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
}
}
} catch (const sycl::exception & exc) {
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
@@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
return GGML_STATUS_SUCCESS;
}
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
model_sycl_graph.end_recording();
+4 -2
View File
@@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
}
}
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io);
if (kv_self != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
kv_self->state_write(io);
}
return io.n_bytes();
}
+8
View File
@@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
void llama_kv_cache_unified::set_full() {
n = size;
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
// setting it to 0 is the simplest way to achieve that
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
head = 0;
}
llama_sbatch llama_kv_cache_unified::sbatch_init(
@@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
void llama_kv_cache_recurrent::set_full() {
n = size;
head = 0;
}
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
+4 -10
View File
@@ -171,11 +171,8 @@ public:
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
@@ -343,11 +340,8 @@ public:
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
+10 -5
View File
@@ -822,13 +822,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
mappings.reserve(files.size());
mmaps_used.reserve(files.size());
for (const auto & file : files) {
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
if (!reg) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
bool is_numa = false;
auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (dev) {
auto * reg = ggml_backend_dev_backend_reg(dev);
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
if (is_numa_fn) {
is_numa = is_numa_fn();
}
}
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa);
mmaps_used.emplace_back(mapping->size(), 0);
if (mlock_mmaps) {
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
+3 -4
View File
@@ -12218,6 +12218,9 @@ struct llm_build_granite : public llm_graph_context {
// inp_pos - built only if rope enabled
ggml_tensor * inp_pos = nullptr;
if (use_rope) {
inp_pos = build_inp_pos();
}
auto * inp_attn = build_attn_inp_kv_unified();
@@ -12260,10 +12263,6 @@ struct llm_build_granite : public llm_graph_context {
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
if (use_rope) {
if (!inp_pos) {
inp_pos = build_inp_pos();
}
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
+1
View File
@@ -144,6 +144,7 @@ endif()
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-regex-partial.cpp)
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
if (NOT WIN32)
+3 -1
View File
@@ -832,7 +832,9 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+288
View File
@@ -0,0 +1,288 @@
// Tests common_regex (esp. its partial final matches support).
#include "common.h"
#include "regex-partial.h"
#include <sstream>
#include <iostream>
#include <optional>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << " Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
struct test_case {
std::string pattern;
struct input_output {
std::string input;
common_regex_match output;
};
std::vector<input_output> inputs_outputs;
};
static std::string common_regex_match_type_name(common_regex_match_type type) {
switch (type) {
case COMMON_REGEX_MATCH_TYPE_NONE:
return "COMMON_REGEX_MATCH_TYPE_NONE";
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
case COMMON_REGEX_MATCH_TYPE_FULL:
return "COMMON_REGEX_MATCH_TYPE_FULL";
}
return "?";
}
static void test_regex() {
printf("[%s]\n", __func__);
auto test = [](const test_case & test_case) {
common_regex cr(test_case.pattern);
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
for (const auto & input_output : test_case.inputs_outputs) {
std::cout << " Input: " << input_output.input << '\n';
auto m = cr.search(input_output.input, 0);
if (m != input_output.output) {
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
std::ostringstream ss;
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
ss << "<no match>";
} else {
GGML_ASSERT(!input_output.output.groups.empty());
std::vector<std::string> parts;
for (const auto & g : m->groups) {
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
}
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
}
return ss.str();
};
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
std::cout << " Got: " << match_to_str(m) << '\n';
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
throw std::runtime_error("Test failed");
}
}
};
test({
"a",
{
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
}
});
test({
"abcd",
{
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"d", {}},
{"bcd", {}},
{"cde", {}},
{"cd", {}},
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
{"abbie", {}},
{"", {}},
}
});
test({
".*?ab",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
}
});
test({
"a.*?b",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"d", {}},
{"b", {}},
}
});
test({
"ab(?:cd){2,4}ef",
{
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abcde", {}},
{"abcdef", {}},
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
{"abcdcdcdcdcdef", {}},
{"abcde", {}},
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
}
});
test({
"a(?:rte| pure )fact",
{
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"fact", {}},
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
{"" , {}},
{"pure", {}},
{"pure fact", {}},
}
});
test({
"abc",
{
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
{"b", {}},
{"c", {}},
{"", {}},
}
});
test({
"(?:abc)?\\s*def",
{
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
}
});
test({
"a+b",
{
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
}
});
test({
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
{
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
}
});
}
static void test_regex_to_reversed_partial_regex() {
printf("[%s]\n", __func__);
assert_equals<std::string>(
"((?:(?:c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abc"));
assert_equals<std::string>(
"(a+)[\\s\\S]*",
regex_to_reversed_partial_regex("a+"));
assert_equals<std::string>(
"(a*)[\\s\\S]*",
regex_to_reversed_partial_regex("a*"));
assert_equals<std::string>(
"(a?)[\\s\\S]*",
regex_to_reversed_partial_regex("a?"));
assert_equals<std::string>(
"([a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]"));
assert_equals<std::string>(
"((?:\\w+)?[a-z])[\\s\\S]*",
regex_to_reversed_partial_regex("[a-z]\\w+"));
assert_equals<std::string>(
"((?:a|b))[\\s\\S]*",
regex_to_reversed_partial_regex("(?:a|b)"));
assert_equals<std::string>(
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("abcd"));
assert_equals<std::string>(
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
regex_to_reversed_partial_regex("a*b"));
assert_equals<std::string>(
"((?:(?:b)?a)?.*)[\\s\\S]*",
regex_to_reversed_partial_regex(".*?ab"));
assert_equals<std::string>(
"((?:(?:b)?.*)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a.*?b"));
assert_equals<std::string>(
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc)d"));
assert_equals<std::string>(
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
regex_to_reversed_partial_regex("a(bc|de)"));
assert_equals<std::string>(
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
regex_to_reversed_partial_regex("ab{2,4}c"));
}
int main() {
test_regex_to_reversed_partial_regex();
test_regex();
std::cout << "All tests passed.\n";
}
+48 -16
View File
@@ -687,7 +687,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
invalid_param = true;
break;
}
auto value = argv[i];
auto * value = argv[i];
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) {
// enumerate all the devices and add their buffer types to the list
@@ -719,7 +719,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
// memory leak present in the implementation
// over in arg.cpp. Acceptable because we
// only parse these args once in this program.
auto override_group = value;
auto * override_group = value;
if (value[override_group_span_len] == '\0') {
value = &value[override_group_span_len];
last_group = true;
@@ -730,7 +730,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
std::vector<llama_model_tensor_buft_override> group_tensor_buft_overrides{};
auto override_span_len = std::strcspn(override_group, ";");
while (override_span_len > 0) {
auto override = override_group;
auto * override = override_group;
if (override_group[override_span_len] != '\0') {
override_group[override_span_len] = '\0';
override_group = &override_group[override_span_len + 1];
@@ -743,9 +743,10 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
override[tensor_name_span_len] = '\0';
auto tensor_name = override;
auto buffer_type = &override[tensor_name_span_len + 1];
auto * tensor_name = override;
auto * buffer_type = &override[tensor_name_span_len + 1];
if (buft_list.find(buffer_type) == buft_list.end()) {
printf("error: unrecognized buffer type '%s'\n", buffer_type);
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
@@ -1736,7 +1737,7 @@ struct sql_printer : public printer {
}
};
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@@ -1753,14 +1754,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
if (res != 0) {
fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
return false;
}
n_processed += n_tokens;
}
llama_synchronize(ctx);
return true;
}
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@@ -1770,10 +1776,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1));
int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
if (res != 0) {
fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
return false;
}
llama_synchronize(ctx);
token = std::rand() % n_vocab;
}
return true;
}
static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
@@ -1816,10 +1827,11 @@ int main(int argc, char ** argv) {
fprintf(stderr, "warning: sanitizer enabled, performance may be affected\n");
#endif
cmd_params params = parse_cmd_params(argc, argv);
// initialize backends
ggml_backend_load_all();
cmd_params params = parse_cmd_params(argc, argv);
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
fprintf(stderr, "%s: error: CPU backend is not loaded\n", __func__);
@@ -1917,13 +1929,21 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
}
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
if (!res) {
fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
exit(1);
}
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
}
test_gen(ctx, 1, t.n_threads);
bool res = test_gen(ctx, 1, t.n_threads);
if (!res) {
fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
exit(1);
}
}
for (int i = 0; i < params.reps; i++) {
@@ -1934,7 +1954,11 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
if (!res) {
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
exit(1);
}
}
uint64_t t_start = get_time_ns();
@@ -1944,14 +1968,22 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
if (!res) {
fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
exit(1);
}
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
test_gen(ctx, t.n_gen, t.n_threads);
bool res = test_gen(ctx, t.n_gen, t.n_threads);
if (!res) {
fprintf(stderr, "%s: error: failed to run gen\n", __func__);
exit(1);
}
}
uint64_t t_ns = get_time_ns() - t_start;
Binary file not shown.
+1 -1
View File
@@ -1429,7 +1429,7 @@ struct server_slot {
pos = text.find(word, from_pos);
} else {
// otherwise, partial stop
pos = find_partial_stop_string(word, text);
pos = string_find_partial_stop(text, word);
}
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
+49
View File
@@ -0,0 +1,49 @@
#!/usr/bin/env python
import pytest
# ensure grandparent path is in sys.path
from pathlib import Path
import sys
from unit.test_tool_call import TEST_TOOL
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
import datetime
from utils import *
server: ServerProcess
TIMEOUT_SERVER_START = 15*60
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2"
server.server_port = 8081
server.n_slots = 1
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
@pytest.mark.parametrize("template_name,format", [
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
])
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
global server
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "user", "content": "What is today?"},
],
"tools": tools,
})
assert res.status_code == 200
prompt = res.body["prompt"]
today_str = datetime.date.today().strftime(format)
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
+1 -1
View File
@@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
global server
n_predict = 512
n_predict = 1024
# server = ServerPreset.stories15m_moe()
server.jinja = True
server.n_predict = n_predict
+12
View File
@@ -643,6 +643,18 @@ static json oaicompat_completion_params_parse(
throw std::runtime_error("Expected 'messages' to be an array");
}
for (auto & msg : messages) {
std::string role = json_value(msg, "role", std::string());
if (role != "assistant" && !msg.contains("content")) {
throw std::runtime_error("All non-assistant messages must contain 'content'");
}
if (role == "assistant") {
if (!msg.contains("content") && !msg.contains("tool_calls")) {
throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
}
if (!msg.contains("content")) {
continue; // avoid errors with no content
}
}
json & content = msg.at("content");
if (content.is_string() || content.is_null()) {
continue;
+272 -6
View File
@@ -18,6 +18,7 @@
"dexie": "^4.0.11",
"highlight.js": "^11.10.0",
"katex": "^0.16.15",
"pdfjs-dist": "^5.2.133",
"postcss": "^8.4.49",
"react": "^18.3.1",
"react-dom": "^18.3.1",
@@ -988,7 +989,7 @@
"version": "0.3.8",
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz",
"integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==",
"dev": true,
"devOptional": true,
"license": "MIT",
"dependencies": {
"@jridgewell/set-array": "^1.2.1",
@@ -1003,7 +1004,7 @@
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz",
"integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==",
"dev": true,
"devOptional": true,
"license": "MIT",
"engines": {
"node": ">=6.0.0"
@@ -1013,30 +1014,224 @@
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz",
"integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==",
"dev": true,
"devOptional": true,
"license": "MIT",
"engines": {
"node": ">=6.0.0"
}
},
"node_modules/@jridgewell/source-map": {
"version": "0.3.6",
"resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz",
"integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==",
"license": "MIT",
"optional": true,
"peer": true,
"dependencies": {
"@jridgewell/gen-mapping": "^0.3.5",
"@jridgewell/trace-mapping": "^0.3.25"
}
},
"node_modules/@jridgewell/sourcemap-codec": {
"version": "1.5.0",
"resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz",
"integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==",
"dev": true,
"devOptional": true,
"license": "MIT"
},
"node_modules/@jridgewell/trace-mapping": {
"version": "0.3.25",
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz",
"integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==",
"dev": true,
"devOptional": true,
"license": "MIT",
"dependencies": {
"@jridgewell/resolve-uri": "^3.1.0",
"@jridgewell/sourcemap-codec": "^1.4.14"
}
},
"node_modules/@napi-rs/canvas": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.70.tgz",
"integrity": "sha512-nD6NGa4JbNYSZYsTnLGrqe9Kn/lCkA4ybXt8sx5ojDqZjr2i0TWAHxx/vhgfjX+i3hCdKWufxYwi7CfXqtITSA==",
"license": "MIT",
"optional": true,
"engines": {
"node": ">= 10"
},
"optionalDependencies": {
"@napi-rs/canvas-android-arm64": "0.1.70",
"@napi-rs/canvas-darwin-arm64": "0.1.70",
"@napi-rs/canvas-darwin-x64": "0.1.70",
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.70",
"@napi-rs/canvas-linux-arm64-gnu": "0.1.70",
"@napi-rs/canvas-linux-arm64-musl": "0.1.70",
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.70",
"@napi-rs/canvas-linux-x64-gnu": "0.1.70",
"@napi-rs/canvas-linux-x64-musl": "0.1.70",
"@napi-rs/canvas-win32-x64-msvc": "0.1.70"
}
},
"node_modules/@napi-rs/canvas-android-arm64": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.70.tgz",
"integrity": "sha512-I/YOuQ0wbkVYxVaYtCgN42WKTYxNqFA0gTcTrHIGG1jfpDSyZWII/uHcjOo4nzd19io6Y4+/BqP8E5hJgf9OmQ==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-darwin-arm64": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.70.tgz",
"integrity": "sha512-4pPGyXetHIHkw2TOJHujt3mkCP8LdDu8+CT15ld9Id39c752RcI0amDHSuMLMQfAjvusA9B5kKxazwjMGjEJpQ==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-darwin-x64": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.70.tgz",
"integrity": "sha512-+2N6Os9LbkmDMHL+raknrUcLQhsXzc5CSXRbXws9C3pv/mjHRVszQ9dhFUUe9FjfPhCJznO6USVdwOtu7pOrzQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-arm-gnueabihf": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.70.tgz",
"integrity": "sha512-QjscX9OaKq/990sVhSMj581xuqLgiaPVMjjYvWaCmAJRkNQ004QfoSMEm3FoTqM4DRoquP8jvuEXScVJsc1rqQ==",
"cpu": [
"arm"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-arm64-gnu": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.70.tgz",
"integrity": "sha512-LNakMOwwqwiHIwMpnMAbFRczQMQ7TkkMyATqFCOtUJNlE6LPP/QiUj/mlFrNbUn/hctqShJ60gWEb52ZTALbVw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-arm64-musl": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.70.tgz",
"integrity": "sha512-wBTOllEYNfJCHOdZj9v8gLzZ4oY3oyPX8MSRvaxPm/s7RfEXxCyZ8OhJ5xAyicsDdbE5YBZqdmaaeP5+xKxvtg==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-riscv64-gnu": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.70.tgz",
"integrity": "sha512-GVUUPC8TuuFqHip0rxHkUqArQnlzmlXmTEBuXAWdgCv85zTCFH8nOHk/YCF5yo0Z2eOm8nOi90aWs0leJ4OE5Q==",
"cpu": [
"riscv64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-x64-gnu": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.70.tgz",
"integrity": "sha512-/kvUa2lZRwGNyfznSn5t1ShWJnr/m5acSlhTV3eXECafObjl0VBuA1HJw0QrilLpb4Fe0VLywkpD1NsMoVDROQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-linux-x64-musl": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.70.tgz",
"integrity": "sha512-aqlv8MLpycoMKRmds7JWCfVwNf1fiZxaU7JwJs9/ExjTD8lX2KjsO7CTeAj5Cl4aEuzxUWbJPUUE2Qu9cZ1vfg==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@napi-rs/canvas-win32-x64-msvc": {
"version": "0.1.70",
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.70.tgz",
"integrity": "sha512-Q9QU3WIpwBTVHk4cPfBjGHGU4U0llQYRXgJtFtYqqGNEOKVN4OT6PQ+ve63xwIPODMpZ0HHyj/KLGc9CWc3EtQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@nodelib/fs.scandir": {
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
@@ -2002,7 +2197,7 @@
"version": "8.14.0",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz",
"integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==",
"dev": true,
"devOptional": true,
"license": "MIT",
"bin": {
"acorn": "bin/acorn"
@@ -2186,6 +2381,14 @@
"devOptional": true,
"license": "MIT/X11"
},
"node_modules/buffer-from": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz",
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
"license": "MIT",
"optional": true,
"peer": true
},
"node_modules/callsites": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
@@ -4843,6 +5046,18 @@
"node": ">=8"
}
},
"node_modules/pdfjs-dist": {
"version": "5.2.133",
"resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-5.2.133.tgz",
"integrity": "sha512-abE6ZWDxztt+gGFzfm4bX2ggfxUk9wsDEoFzIJm9LozaY3JdXR7jyLK4Bjs+XLXplCduuWS1wGhPC4tgTn/kzg==",
"license": "Apache-2.0",
"engines": {
"node": ">=20.16.0 || >=22.3.0"
},
"optionalDependencies": {
"@napi-rs/canvas": "^0.1.67"
}
},
"node_modules/picocolors": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz",
@@ -5753,6 +5968,17 @@
"node": ">=8"
}
},
"node_modules/source-map": {
"version": "0.6.1",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
"license": "BSD-3-Clause",
"optional": true,
"peer": true,
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/source-map-js": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
@@ -5762,6 +5988,18 @@
"node": ">=0.10.0"
}
},
"node_modules/source-map-support": {
"version": "0.5.21",
"resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz",
"integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==",
"license": "MIT",
"optional": true,
"peer": true,
"dependencies": {
"buffer-from": "^1.0.0",
"source-map": "^0.6.0"
}
},
"node_modules/space-separated-tokens": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz",
@@ -5859,6 +6097,34 @@
"node": ">=6"
}
},
"node_modules/terser": {
"version": "5.39.1",
"resolved": "https://registry.npmjs.org/terser/-/terser-5.39.1.tgz",
"integrity": "sha512-Mm6+uad0ZuDtcV8/4uOZQDQ8RuiC5Pu+iZRedJtF7yA/27sPL7d++In/AJKpWZlU3SYMPPkVfwetn6sgZ66pUA==",
"license": "BSD-2-Clause",
"optional": true,
"peer": true,
"dependencies": {
"@jridgewell/source-map": "^0.3.3",
"acorn": "^8.8.2",
"commander": "^2.20.0",
"source-map-support": "~0.5.20"
},
"bin": {
"terser": "bin/terser"
},
"engines": {
"node": ">=10"
}
},
"node_modules/terser/node_modules/commander": {
"version": "2.20.3",
"resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz",
"integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==",
"license": "MIT",
"optional": true,
"peer": true
},
"node_modules/textlinestream": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz",
+1
View File
@@ -21,6 +21,7 @@
"dexie": "^4.0.11",
"highlight.js": "^11.10.0",
"katex": "^0.16.15",
"pdfjs-dist": "^5.2.133",
"postcss": "^8.4.49",
"react": "^18.3.1",
"react-dom": "^18.3.1",
+4
View File
@@ -16,6 +16,8 @@ export const CONFIG_DEFAULT = {
showTokensPerSecond: false,
showThoughtInProgress: false,
excludeThoughtOnReq: true,
pasteLongTextToFileLen: 2500,
pdfAsImage: false,
// make sure these default values are in sync with `common.h`
samplers: 'edkypmxt',
temperature: 0.8,
@@ -43,6 +45,8 @@ export const CONFIG_DEFAULT = {
export const CONFIG_INFO: Record<string, string> = {
apiKey: 'Set the API Key if you are using --api-key option for the server.',
systemMessage: 'The starting message that defines how model should behave.',
pasteLongTextToFileLen:
'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.',
samplers:
'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
temperature:
@@ -306,6 +306,7 @@ function ChatInput({
onStop: () => void;
isGenerating: boolean;
}) {
const { config } = useAppContext();
const [isDrag, setIsDrag] = useState(false);
return (
@@ -328,7 +329,28 @@ function ChatInput({
{({ getRootProps, getInputProps }) => (
<div
className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
// when a file is pasted to the input, we handle it here
// if a text is pasted, and if it is long text, we will convert it to a file
onPasteCapture={(e: ClipboardEvent<HTMLInputElement>) => {
const text = e.clipboardData.getData('text/plain');
if (
text.length > 0 &&
config.pasteLongTextToFileLen > 0 &&
text.length > config.pasteLongTextToFileLen
) {
// if the text is too long, we will convert it to a file
extraContext.addItems([
{
type: 'context',
name: 'Pasted Content',
content: text,
},
]);
e.preventDefault();
return;
}
// if a file is pasted, we will handle it here
const files = Array.from(e.clipboardData.items)
.filter((item) => item.kind === 'file')
.map((item) => item.getAsFile())
@@ -100,6 +100,16 @@ const SETTING_SECTIONS: SettingSection[] = [
key,
}) as SettingFieldInput
),
{
type: SettingInputType.SHORT_INPUT,
label: 'Paste length to file',
key: 'pasteLongTextToFileLen',
},
{
type: SettingInputType.CHECKBOX,
label: 'Parse PDF as image instead of text',
key: 'pdfAsImage',
},
],
},
{
@@ -452,10 +462,10 @@ function SettingsModalLongInput({
label?: string;
}) {
return (
<label className="form-control mb-2">
<div className="label inline">{label || configKey}</div>
<label className="form-control">
<div className="label inline text-sm">{label || configKey}</div>
<textarea
className="textarea textarea-bordered h-24"
className="textarea textarea-bordered h-24 mb-2"
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
value={value}
onChange={(e) => onChange(e.target.value)}
@@ -482,9 +492,7 @@ function SettingsModalShortInput({
<>
{/* on mobile, we simply show the help message here */}
{helpMsg && (
<div className="block md:hidden mb-1">
<b>{label || configKey}</b>
<br />
<div className="block mb-1 opacity-75">
<p className="text-xs">{helpMsg}</p>
</div>
)}
@@ -493,11 +501,6 @@ function SettingsModalShortInput({
<div tabIndex={0} role="button" className="font-bold hidden md:block">
{label || configKey}
</div>
{helpMsg && (
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
{helpMsg}
</div>
)}
</div>
<input
type="text"
@@ -2,6 +2,17 @@ import { useState } from 'react';
import { MessageExtra } from '../utils/types';
import toast from 'react-hot-toast';
import { useAppContext } from '../utils/app.context';
import * as pdfjs from 'pdfjs-dist';
import pdfjsWorkerSrc from 'pdfjs-dist/build/pdf.worker.min.mjs?url';
import { TextContent, TextItem } from 'pdfjs-dist/types/src/display/api';
pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
// This file handles uploading extra context items (a.k.a files)
// It allows processing these kinds of files:
// - image files (converted to base64)
// - text files (including code files)
// - pdf (converted to text)
// Interface describing the API returned by the hook
export interface ChatExtraContextApi {
@@ -13,7 +24,7 @@ export interface ChatExtraContextApi {
}
export function useChatExtraContext(): ChatExtraContextApi {
const { serverProps } = useAppContext();
const { serverProps, config } = useAppContext();
const [items, setItems] = useState<MessageExtra[]>([]);
const addItems = (newItems: MessageExtra[]) => {
@@ -28,6 +39,8 @@ export function useChatExtraContext(): ChatExtraContextApi {
setItems([]);
};
const isSupportVision = serverProps?.modalities?.vision;
const onFileAdded = (files: File[]) => {
for (const file of files) {
const mimeType = file.type;
@@ -38,7 +51,7 @@ export function useChatExtraContext(): ChatExtraContextApi {
}
if (mimeType.startsWith('image/')) {
if (!serverProps?.modalities?.vision) {
if (!isSupportVision) {
toast.error('Multimodal is not supported by this server or model.');
break;
}
@@ -69,7 +82,43 @@ export function useChatExtraContext(): ChatExtraContextApi {
toast.error('Video and audio files are not supported yet.');
break;
} else if (mimeType.startsWith('application/pdf')) {
toast.error('PDF files are not supported yet.');
if (config.pdfAsImage && !isSupportVision) {
toast(
'Multimodal is not supported, PDF will be converted to text instead of image.'
);
break;
}
const promise =
config.pdfAsImage && isSupportVision
? convertPDFToImage(file).then((base64Urls) => {
addItems(
base64Urls.map((base64Url) => ({
type: 'imageFile',
name: file.name,
base64Url,
}))
);
})
: convertPDFToText(file).then((content) => {
if (isSupportVision) {
toast.success(
'PDF file converted to text. You can also convert it to image, see in Settings.'
);
}
addItems([
{
type: 'textFile',
name: file.name,
content,
},
]);
});
promise.catch((error) => {
console.error(error);
toast.error('Failed to parse PDF file.');
});
break;
} else {
// Because there can be many text file types (like code file), we will not check the mime type
@@ -105,11 +154,69 @@ export function useChatExtraContext(): ChatExtraContextApi {
};
}
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (event) => {
if (event.target?.result) {
resolve(event.target.result as ArrayBuffer);
} else {
reject(new Error('Failed to read file.'));
}
};
reader.readAsArrayBuffer(file);
});
}
async function convertPDFToText(file: File): Promise<string> {
const buffer = await getFileAsBuffer(file);
const pdf = await pdfjs.getDocument(buffer).promise;
const numPages = pdf.numPages;
const textContentPromises: Promise<TextContent>[] = [];
for (let i = 1; i <= numPages; i++) {
textContentPromises.push(
pdf.getPage(i).then((page) => page.getTextContent())
);
}
const textContents = await Promise.all(textContentPromises);
const textItems = textContents.flatMap((textContent: TextContent) =>
textContent.items.map((item) => (item as TextItem).str ?? '')
);
return textItems.join('\n');
}
// returns list of base64 images
async function convertPDFToImage(file: File): Promise<string[]> {
const buffer = await getFileAsBuffer(file);
const doc = await pdfjs.getDocument(buffer).promise;
const pages: Promise<string>[] = [];
for (let i = 1; i <= doc.numPages; i++) {
const page = await doc.getPage(i);
const viewport = page.getViewport({ scale: 1.5 });
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
canvas.width = viewport.width;
canvas.height = viewport.height;
if (!ctx) {
throw new Error('Failed to get 2D context from canvas');
}
const task = page.render({ canvasContext: ctx, viewport: viewport });
pages.push(
task.promise.then(() => {
return canvas.toDataURL();
})
);
}
return await Promise.all(pages);
}
// WARN: vibe code below
// This code is a heuristic to determine if a string is likely not binary.
// It is necessary because input file can have various mime types which we don't have time to investigate.
// For example, a python file can be text/plain, application/x-python, etc.
export function isLikelyNotBinary(str: string): boolean {
function isLikelyNotBinary(str: string): boolean {
const options = {
prefixLength: 1024 * 10, // Check the first 10KB of the string
suspiciousCharThresholdRatio: 0.15, // Allow up to 15% suspicious chars
+1 -1
View File
@@ -7,7 +7,7 @@ import * as fflate from 'fflate';
/* eslint-disable */
const MAX_BUNDLE_SIZE = 1.5 * 1024 * 1024; // only increase when absolutely necessary
const MAX_BUNDLE_SIZE = 2 * 1024 * 1024; // only increase when absolutely necessary
const GUIDE_FOR_FRONTEND = `
<!--