forked from wylab/llama.cpp
Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5364ae4ba5 | |||
| 7c07ac244d | |||
| 0a338ed013 | |||
| bc098c3cf0 | |||
| c6a2c9e741 | |||
| 07ad2b6db3 | |||
| c531edfa34 | |||
| 02cdd2d8b0 | |||
| 64bb51cf90 | |||
| 9c404ed54c | |||
| 6c8b91500e | |||
| 3cc1f1f1d2 | |||
| c753d7bed0 | |||
| b2838049cc | |||
| aa48e373f2 | |||
| e3a9421b78 | |||
| 5ab5d5fb25 | |||
| 3198405e98 | |||
| f5170c1d7a | |||
| 017f10b5fa | |||
| 4696d56749 | |||
| b7d2672082 | |||
| 6da34fa276 | |||
| 5e7d95e22e | |||
| 053174436f | |||
| 360a9c98e1 | |||
| 09d13d94fb | |||
| 24e86cae72 | |||
| bb1681fbd5 | |||
| d486dd3e8e | |||
| 21ca987fba | |||
| be1d4a13db | |||
| ab3971f2a0 | |||
| e5c834f718 | |||
| 71bdbdb587 | |||
| f0995d28ce | |||
| c252e0c409 |
@@ -140,3 +140,94 @@ jobs:
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
ubuntu-24-ppc64el-cpu-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup PowerPC64le
|
||||
run: |
|
||||
sudo dpkg --add-architecture ppc64el
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
EOF
|
||||
|
||||
sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
gcc-14-powerpc64le-linux-gnu \
|
||||
g++-14-powerpc64le-linux-gnu \
|
||||
libcurl4-openssl-dev:ppc64el
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=ppc64 \
|
||||
-DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
ubuntu-24-ppc64el-vulkan-cross:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup PowerPC64le
|
||||
run: |
|
||||
sudo dpkg --add-architecture ppc64el
|
||||
|
||||
# Add arch-specific repositories for non-amd64 architectures
|
||||
cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe
|
||||
deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe
|
||||
EOF
|
||||
|
||||
sudo apt-get update || true ;# Prevent failure due to missing URLs.
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
glslc \
|
||||
gcc-14-powerpc64le-linux-gnu \
|
||||
g++-14-powerpc64le-linux-gnu \
|
||||
libvulkan-dev:ppc64el \
|
||||
libcurl4-openssl-dev:ppc64el
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_VULKAN=ON \
|
||||
-DGGML_OPENMP=OFF \
|
||||
-DLLAMA_BUILD_EXAMPLES=ON \
|
||||
-DLLAMA_BUILD_TOOLS=ON \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=ppc64 \
|
||||
-DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \
|
||||
-DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON \
|
||||
-DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \
|
||||
-DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
@@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
|
||||
minja/minja.hpp
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
regex-partial.cpp
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
|
||||
+130
-110
@@ -6,6 +6,15 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
|
||||
auto time = std::chrono::system_clock::to_time_t(now);
|
||||
auto local_time = *std::localtime(&time);
|
||||
std::ostringstream ss;
|
||||
ss << std::put_time(&local_time, format.c_str());
|
||||
auto res = ss.str();
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
@@ -24,6 +33,7 @@ struct templates_params {
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
||||
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
if (!inputs.tools.is_null()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
||||
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
||||
expect_tool_parameters(name, parameters, {"query"});
|
||||
} else if (name == "python" || name == "code_interpreter") {
|
||||
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
||||
expect_tool_parameters(name, parameters, {"code"});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (allow_python_tag_builtin_tools) {
|
||||
handle_builtin_tool(name, parameters);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
});
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
|
||||
std::vector<std::string> kvs;
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
||||
}
|
||||
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
||||
builtin_tools.push_back(name);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
||||
if (allow_python_tag_builtin_tools) {
|
||||
handle_builtin_tool(name, parameters);
|
||||
}
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"{\" space "
|
||||
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
||||
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
||||
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
||||
"\"}\" space"));
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
});
|
||||
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
||||
});
|
||||
if (!builtin_tools.empty()) {
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
});
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"date_string", format_time(inputs.now, "%d %b %Y")},
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
const auto & type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
if (!inputs.tools.is_null()) {
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
const auto & parameters = function.at("parameters");
|
||||
std::string name = function.at("name");
|
||||
if (name == "python" || name == "ipython") {
|
||||
if (!parameters.contains("type")) {
|
||||
throw std::runtime_error("Missing type in python tool");
|
||||
}
|
||||
has_raw_python = true;
|
||||
const auto & type = parameters.at("type");
|
||||
if (type == "object") {
|
||||
auto properties = parameters.at("properties");
|
||||
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
||||
if (it.value().at("type") == "string") {
|
||||
if (!python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("Multiple string arguments found in python tool");
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
python_code_argument_name = it.key();
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
if (python_code_argument_name.empty()) {
|
||||
throw std::runtime_error("No string argument found in python tool");
|
||||
}
|
||||
} else if (type != "string") {
|
||||
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
if (has_raw_python) {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
||||
data.preserved_tokens.push_back("<|python_tag|>");
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
params.extract_reasoning = inputs.extract_reasoning;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_firefunction_v2(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
||||
}
|
||||
|
||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
||||
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Mistral Nemo (w/ tools)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
|
||||
@@ -443,6 +443,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$0");
|
||||
|
||||
+4
-4
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
@@ -503,10 +504,9 @@ static bool string_starts_with(const std::string & str,
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
static bool string_ends_with(const std::string & str,
|
||||
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
|
||||
@@ -13,10 +13,12 @@
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <exception>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -393,8 +395,8 @@ class chat_template {
|
||||
|
||||
for (const auto & message_ : adjusted_messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
|
||||
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
|
||||
@@ -415,7 +417,6 @@ class chat_template {
|
||||
}
|
||||
}
|
||||
if (polyfill_tool_calls) {
|
||||
auto content = message.at("content");
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call.at("type") != "function") {
|
||||
@@ -434,8 +435,11 @@ class chat_template {
|
||||
auto obj = json {
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (!content.is_null() && !content.empty()) {
|
||||
obj["content"] = content;
|
||||
if (message.contains("content")) {
|
||||
auto content = message.at("content");
|
||||
if (!content.is_null() && !content.empty()) {
|
||||
obj["content"] = content;
|
||||
}
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("tool_calls");
|
||||
|
||||
+69
-36
@@ -11,6 +11,7 @@
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
@@ -233,7 +234,7 @@ public:
|
||||
}
|
||||
} else if (is_object()) {
|
||||
if (!index.is_hashable())
|
||||
throw std::runtime_error("Unashable type: " + index.dump());
|
||||
throw std::runtime_error("Unhashable type: " + index.dump());
|
||||
auto it = object_->find(index.primitive_);
|
||||
if (it == object_->end())
|
||||
throw std::runtime_error("Key not found: " + index.dump());
|
||||
@@ -252,7 +253,7 @@ public:
|
||||
auto index = key.get<int>();
|
||||
return array_->at(index < 0 ? array_->size() + index : index);
|
||||
} else if (object_) {
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||
auto it = object_->find(key.primitive_);
|
||||
if (it == object_->end()) return Value();
|
||||
return it->second;
|
||||
@@ -261,7 +262,7 @@ public:
|
||||
}
|
||||
void set(const Value& key, const Value& value) {
|
||||
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
||||
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||
(*object_)[key.primitive_] = value;
|
||||
}
|
||||
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
|
||||
@@ -398,7 +399,7 @@ public:
|
||||
}
|
||||
return false;
|
||||
} else if (object_) {
|
||||
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
|
||||
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
|
||||
return object_->find(value.primitive_) != object_->end();
|
||||
} else {
|
||||
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
|
||||
@@ -416,7 +417,7 @@ public:
|
||||
return const_cast<Value*>(this)->at(index);
|
||||
}
|
||||
Value& at(const Value & index) {
|
||||
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
|
||||
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
|
||||
if (is_array()) return array_->at(index.get<int>());
|
||||
if (is_object()) return object_->at(index.primitive_);
|
||||
throw std::runtime_error("Value is not an array or object: " + dump());
|
||||
@@ -676,8 +677,8 @@ public:
|
||||
class VariableExpr : public Expression {
|
||||
std::string name;
|
||||
public:
|
||||
VariableExpr(const Location & location, const std::string& n)
|
||||
: Expression(location), name(n) {}
|
||||
VariableExpr(const Location & loc, const std::string& n)
|
||||
: Expression(loc), name(n) {}
|
||||
std::string get_name() const { return name; }
|
||||
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
|
||||
if (!context->contains(name)) {
|
||||
@@ -1200,9 +1201,9 @@ public:
|
||||
|
||||
class SliceExpr : public Expression {
|
||||
public:
|
||||
std::shared_ptr<Expression> start, end;
|
||||
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
|
||||
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
|
||||
std::shared_ptr<Expression> start, end, step;
|
||||
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
|
||||
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
|
||||
Value do_evaluate(const std::shared_ptr<Context> &) const override {
|
||||
throw std::runtime_error("SliceExpr not implemented");
|
||||
}
|
||||
@@ -1219,18 +1220,35 @@ public:
|
||||
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
|
||||
auto target_value = base->evaluate(context);
|
||||
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
|
||||
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
|
||||
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
|
||||
auto len = target_value.size();
|
||||
auto wrap = [len](int64_t i) -> int64_t {
|
||||
if (i < 0) {
|
||||
return i + len;
|
||||
}
|
||||
return i;
|
||||
};
|
||||
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
|
||||
if (!step) {
|
||||
throw std::runtime_error("slice step cannot be zero");
|
||||
}
|
||||
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
|
||||
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
|
||||
if (target_value.is_string()) {
|
||||
std::string s = target_value.get<std::string>();
|
||||
if (start < 0) start = s.size() + start;
|
||||
if (end < 0) end = s.size() + end;
|
||||
return s.substr(start, end - start);
|
||||
|
||||
std::string result;
|
||||
if (start < end && step == 1) {
|
||||
result = s.substr(start, end - start);
|
||||
} else {
|
||||
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
|
||||
result += s[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
||||
} else if (target_value.is_array()) {
|
||||
if (start < 0) start = target_value.size() + start;
|
||||
if (end < 0) end = target_value.size() + end;
|
||||
auto result = Value::array();
|
||||
for (auto i = start; i < end; ++i) {
|
||||
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
|
||||
result.push_back(target_value.at(i));
|
||||
}
|
||||
return result;
|
||||
@@ -1305,6 +1323,8 @@ public:
|
||||
if (name == "iterable") return l.is_iterable();
|
||||
if (name == "sequence") return l.is_array();
|
||||
if (name == "defined") return !l.is_null();
|
||||
if (name == "true") return l.to_bool();
|
||||
if (name == "false") return !l.to_bool();
|
||||
throw std::runtime_error("Unknown type for 'is' operator: " + name);
|
||||
};
|
||||
auto value = eval();
|
||||
@@ -1520,6 +1540,10 @@ public:
|
||||
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
|
||||
auto suffix = vargs.args[0].get<std::string>();
|
||||
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
|
||||
} else if (method->get_name() == "startswith") {
|
||||
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
|
||||
auto prefix = vargs.args[0].get<std::string>();
|
||||
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
|
||||
} else if (method->get_name() == "title") {
|
||||
vargs.expectArgs("title method", {0, 0}, {0, 0});
|
||||
auto res = str;
|
||||
@@ -2082,28 +2106,37 @@ private:
|
||||
|
||||
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
|
||||
if (!consumeToken("[").empty()) {
|
||||
std::shared_ptr<Expression> index;
|
||||
std::shared_ptr<Expression> index;
|
||||
auto slice_loc = get_location();
|
||||
std::shared_ptr<Expression> start, end, step;
|
||||
bool has_first_colon = false, has_second_colon = false;
|
||||
|
||||
if (!peekSymbols({ ":" })) {
|
||||
start = parseExpression();
|
||||
}
|
||||
|
||||
if (!consumeToken(":").empty()) {
|
||||
has_first_colon = true;
|
||||
if (!peekSymbols({ ":", "]" })) {
|
||||
end = parseExpression();
|
||||
}
|
||||
if (!consumeToken(":").empty()) {
|
||||
auto slice_end = parseExpression();
|
||||
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
|
||||
} else {
|
||||
auto slice_start = parseExpression();
|
||||
if (!consumeToken(":").empty()) {
|
||||
consumeSpaces();
|
||||
if (peekSymbols({ "]" })) {
|
||||
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
|
||||
} else {
|
||||
auto slice_end = parseExpression();
|
||||
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
|
||||
}
|
||||
} else {
|
||||
index = std::move(slice_start);
|
||||
has_second_colon = true;
|
||||
if (!peekSymbols({ "]" })) {
|
||||
step = parseExpression();
|
||||
}
|
||||
}
|
||||
if (!index) throw std::runtime_error("Empty index in subscript");
|
||||
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
|
||||
}
|
||||
|
||||
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
|
||||
if ((has_first_colon || has_second_colon) && (start || end || step)) {
|
||||
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
|
||||
} else {
|
||||
index = std::move(start);
|
||||
}
|
||||
if (!index) throw std::runtime_error("Empty index in subscript");
|
||||
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
|
||||
|
||||
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
|
||||
} else if (!consumeToken(".").empty()) {
|
||||
auto identifier = parseIdentifier();
|
||||
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
#include "regex-partial.h"
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
|
||||
common_regex::common_regex(const std::string & pattern) :
|
||||
pattern(pattern),
|
||||
rx(pattern),
|
||||
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||
|
||||
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
|
||||
std::smatch match;
|
||||
if (pos > input.size()) {
|
||||
throw std::runtime_error("Position out of bounds");
|
||||
}
|
||||
auto start = input.begin() + pos;
|
||||
auto found = as_match
|
||||
? std::regex_match(start, input.end(), match, rx)
|
||||
: std::regex_search(start, input.end(), match, rx);
|
||||
if (found) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
|
||||
for (size_t i = 0; i < match.size(); ++i) {
|
||||
auto begin = pos + match.position(i);
|
||||
res.groups.emplace_back(begin, begin + match.length(i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
std::match_results<std::string::const_reverse_iterator> srmatch;
|
||||
if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
|
||||
auto group = srmatch[1].str();
|
||||
if (group.length() != 0) {
|
||||
auto it = srmatch[1].second.base();
|
||||
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
|
||||
if ((!as_match) || it == input.begin()) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
|
||||
const size_t begin = std::distance(input.begin(), it);
|
||||
const size_t end = input.size();
|
||||
if (begin == std::string::npos || end == std::string::npos || begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
res.groups.push_back({begin, end});
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||
|
||||
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||
|
||||
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
|
||||
- /a|b/ -> (a|b).*
|
||||
- /a*?/ -> error, could match ""
|
||||
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
|
||||
- /.*?ab/ -> ((?:b)?a).* (merge .*)
|
||||
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
|
||||
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
|
||||
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
|
||||
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
|
||||
|
||||
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
|
||||
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
|
||||
*/
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||
auto it = pattern.begin();
|
||||
const auto end = pattern.end();
|
||||
|
||||
std::function<std::string()> process = [&]() {
|
||||
std::vector<std::vector<std::string>> alternatives(1);
|
||||
std::vector<std::string> * sequence = &alternatives.back();
|
||||
|
||||
while (it != end) {
|
||||
if (*it == '[') {
|
||||
auto start = it;
|
||||
++it;
|
||||
while (it != end) {
|
||||
if ((*it == '\\') && (++it != end)) {
|
||||
++it;
|
||||
} else if ((it != end) && (*it == ']')) {
|
||||
break;
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '[' in pattern");
|
||||
}
|
||||
++it;
|
||||
sequence->push_back(std::string(start, it));
|
||||
} else if (*it == '*' || *it == '?' || *it == '+') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Quantifier without preceding element");
|
||||
}
|
||||
sequence->back() += *it;
|
||||
auto is_star = *it == '*';
|
||||
++it;
|
||||
if (is_star) {
|
||||
if (*it == '?') {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
} else if (*it == '{') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Repetition without preceding element");
|
||||
}
|
||||
++it;
|
||||
auto start = it;
|
||||
while (it != end && *it != '}') {
|
||||
++it;
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '{' in pattern");
|
||||
}
|
||||
auto parts = string_split(std::string(start, it), ",");
|
||||
++it;
|
||||
if (parts.size() > 2) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
|
||||
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
|
||||
if (s.empty()) {
|
||||
return def;
|
||||
}
|
||||
return std::stoi(s);
|
||||
};
|
||||
auto min = parseOptInt(parts[0], 0);
|
||||
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
|
||||
if (min && max && *max < *min) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
|
||||
auto part = sequence->back();
|
||||
sequence->pop_back();
|
||||
for (int i = 0; i < *min; i++) {
|
||||
sequence->push_back(part);
|
||||
}
|
||||
if (max) {
|
||||
for (int i = *min; i < *max; i++) {
|
||||
sequence->push_back(part + "?");
|
||||
}
|
||||
} else {
|
||||
sequence->push_back(part + "*");
|
||||
}
|
||||
} else if (*it == '(') {
|
||||
++it;
|
||||
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
|
||||
it += 2;
|
||||
}
|
||||
auto sub = process();
|
||||
if (*it != ')') {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
++it;
|
||||
auto & part = sequence->emplace_back("(?:");
|
||||
part += sub;
|
||||
part += ")";
|
||||
} else if (*it == ')') {
|
||||
break;
|
||||
} else if (*it == '|') {
|
||||
++it;
|
||||
alternatives.emplace_back();
|
||||
sequence = &alternatives.back();
|
||||
} else if (*it == '\\' && (++it != end)) {
|
||||
auto str = std::string("\\") + *it;
|
||||
sequence->push_back(str);
|
||||
++it;
|
||||
} else if (it != end) {
|
||||
sequence->push_back(std::string(1, *it));
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
|
||||
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
||||
// We'll do the outermost capturing group and final .* in the enclosing function.
|
||||
std::vector<std::string> res_alts;
|
||||
for (const auto & parts : alternatives) {
|
||||
auto & res = res_alts.emplace_back();
|
||||
for (size_t i = 0; i < parts.size() - 1; i++) {
|
||||
res += "(?:";
|
||||
}
|
||||
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
|
||||
res += *it;
|
||||
if (it != parts.rend() - 1) {
|
||||
res += ")?";
|
||||
}
|
||||
}
|
||||
}
|
||||
return string_join(res_alts, "|");
|
||||
};
|
||||
auto res = process();
|
||||
if (it != end) {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
|
||||
return "(" + res + ")[\\s\\S]*";
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
#pragma once
|
||||
|
||||
#include <regex>
|
||||
#include <string>
|
||||
|
||||
enum common_regex_match_type {
|
||||
COMMON_REGEX_MATCH_TYPE_NONE,
|
||||
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
||||
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||
};
|
||||
|
||||
struct common_string_range {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
if (begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
}
|
||||
// prevent default ctor
|
||||
common_string_range() = delete;
|
||||
bool empty() const {
|
||||
return begin == end;
|
||||
}
|
||||
bool operator==(const common_string_range & other) const {
|
||||
return begin == other.begin && end == other.end;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_regex_match {
|
||||
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||
std::vector<common_string_range> groups;
|
||||
|
||||
bool operator==(const common_regex_match & other) const {
|
||||
return type == other.type && groups == other.groups;
|
||||
}
|
||||
bool operator!=(const common_regex_match & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
class common_regex {
|
||||
std::string pattern;
|
||||
std::regex rx;
|
||||
std::regex rx_reversed_partial;
|
||||
|
||||
public:
|
||||
explicit common_regex(const std::string & pattern);
|
||||
|
||||
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
||||
|
||||
const std::string & str() const { return pattern; }
|
||||
};
|
||||
|
||||
// For testing only (pretty print of failures).
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
||||
@@ -2069,6 +2069,9 @@ class Llama4Model(LlamaModel):
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.startswith("language_model."):
|
||||
name = name.replace("language_model.", "")
|
||||
|
||||
# split the gate_up into gate and up
|
||||
if "gate_up_proj" in name:
|
||||
name_up = name.replace("gate_up_proj", "up_proj.weight")
|
||||
|
||||
@@ -731,6 +731,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
|
||||
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
|
||||
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
|
||||
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
|
||||
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
|
||||
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
|
||||
|
||||
@@ -741,6 +742,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
||||
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
|
||||
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
|
||||
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
|
||||
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
|
||||
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -31,7 +31,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
||||
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default.
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org
|
||||
|
||||
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
|
||||
|
||||
|
||||
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
|
||||
option(GGML_SYCL "ggml: use SYCL" OFF)
|
||||
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
|
||||
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
|
||||
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
|
||||
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
|
||||
"ggml: sycl target device")
|
||||
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
|
||||
|
||||
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||
assert((nrc == 2) || (nrc == 1));
|
||||
#else
|
||||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nrc == 2) {
|
||||
const block_q6_K * GGML_RESTRICT x0 = x;
|
||||
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
|
||||
const block_q8_K * GGML_RESTRICT y0 = y;
|
||||
const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
|
||||
|
||||
float32x4_t vfsum = vdupq_n_f32(0.0f);
|
||||
|
||||
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
|
||||
const uint8_t * GGML_RESTRICT ql0 = x0->ql;
|
||||
const uint8_t * GGML_RESTRICT ql1 = x1->ql;
|
||||
const uint8_t * GGML_RESTRICT qh0 = x0->qh;
|
||||
const uint8_t * GGML_RESTRICT qh1 = x1->qh;
|
||||
const int8_t * GGML_RESTRICT qy0 = y0->qs;
|
||||
const int8_t * GGML_RESTRICT qy1 = y1->qs;
|
||||
|
||||
const uint8x16_t mone = vdupq_n_u8(0x30);
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
int32x4_t visum = vdupq_n_s32(0);
|
||||
|
||||
// process 8 blocks per iteration, totally 16 blocks
|
||||
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
|
||||
int8x16_t vx0[8], vx1[8];
|
||||
|
||||
// de-quantize vx0[8]
|
||||
{
|
||||
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
|
||||
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
|
||||
|
||||
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||
|
||||
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||
}
|
||||
|
||||
// de-quantize vx1[8]
|
||||
{
|
||||
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
|
||||
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
|
||||
|
||||
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
||||
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
||||
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
||||
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
||||
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
||||
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
||||
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
||||
|
||||
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
||||
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
||||
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
||||
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
||||
|
||||
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
||||
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
||||
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
||||
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
||||
}
|
||||
|
||||
// process 16 elements (one block with same scale) per iteration
|
||||
// - vx = concat(ql, qh) - 32
|
||||
// - r1,r2,r3,r4 = smmla(vx, vy)
|
||||
for (int k = 0; k < 8; ++k) {
|
||||
const int blk = j * 8 + k;
|
||||
|
||||
const int8x16_t vy0 = vld1q_s8(qy0);
|
||||
const int8x16_t vy1 = vld1q_s8(qy1);
|
||||
qy0 += 16;
|
||||
qy1 += 16;
|
||||
|
||||
const int32x4_t block_scale = {
|
||||
x0->scales[blk],
|
||||
x0->scales[blk],
|
||||
x1->scales[blk],
|
||||
x1->scales[blk],
|
||||
};
|
||||
|
||||
// calculate four results at once with outer product
|
||||
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
||||
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
||||
int32x4_t vr = vdupq_n_s32(0);
|
||||
vr = vmmlaq_s32(vr, vx_l, vy_l);
|
||||
vr = vmmlaq_s32(vr, vx_h, vy_h);
|
||||
|
||||
// apply block scale, will NOT overflow
|
||||
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
|
||||
visum = vmlaq_s32(visum, vr, block_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// adjust bias, apply superblock scale
|
||||
{
|
||||
int32_t bias[4];
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
||||
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
|
||||
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
|
||||
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
|
||||
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
|
||||
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
|
||||
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
|
||||
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
|
||||
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
|
||||
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
|
||||
const svint64_t zero = svdup_n_s64(0);
|
||||
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
|
||||
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
|
||||
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
|
||||
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
|
||||
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
|
||||
#else
|
||||
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
||||
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
||||
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
||||
|
||||
int8x16_t scales_s8 = vld1q_s8(x0->scales);
|
||||
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||
scales_s8 = vld1q_s8(x1->scales);
|
||||
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
||||
|
||||
int32x4_t prod;
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||
bias[0] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
|
||||
bias[1] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||
bias[2] = vaddvq_s32(prod);
|
||||
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
|
||||
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
|
||||
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
||||
bias[3] = vaddvq_s32(prod);
|
||||
|
||||
#endif
|
||||
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
||||
|
||||
const float32x4_t superblock_scale = {
|
||||
GGML_FP16_TO_FP32(x0->d) * y0->d,
|
||||
GGML_FP16_TO_FP32(x0->d) * y1->d,
|
||||
GGML_FP16_TO_FP32(x1->d) * y0->d,
|
||||
GGML_FP16_TO_FP32(x1->d) * y1->d,
|
||||
};
|
||||
|
||||
visum = vsubq_s32(visum, vibias);
|
||||
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// vfsum = ABCD -> ACBD
|
||||
// AC -> s, BD -> (s+bs)
|
||||
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
|
||||
vst1_f32(s, vget_low_f32 (vfsum));
|
||||
vst1_f32(s + bs, vget_high_f32(vfsum));
|
||||
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||
float sum = 0;
|
||||
|
||||
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_q6_K,
|
||||
.vec_dot = ggml_vec_dot_q6_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
.nrows = 2,
|
||||
#else
|
||||
.nrows = 1,
|
||||
#endif
|
||||
},
|
||||
[GGML_TYPE_IQ2_XXS] = {
|
||||
.from_float = NULL,
|
||||
|
||||
@@ -678,10 +678,14 @@ void launch_fattn(
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const bool is_mla = DV == 512; // TODO better parameterization
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
@@ -689,6 +693,10 @@ void launch_fattn(
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
@@ -713,10 +721,10 @@ void launch_fattn(
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||
size_t nb21 = V ? V->nb[1] : nb11;
|
||||
size_t nb22 = V ? V->nb[2] : nb12;
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
||||
@@ -733,7 +741,7 @@ void launch_fattn(
|
||||
nb13 = nb13*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
|
||||
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 32;
|
||||
static constexpr int nbatch_V2 = 32;
|
||||
static constexpr int nbatch_combine = 32;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 40;
|
||||
static constexpr int nbatch_V2 = 40;
|
||||
static constexpr int nbatch_combine = 40;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 40;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 48;
|
||||
static constexpr int nbatch_V2 = 48;
|
||||
static constexpr int nbatch_combine = 48;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 48;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 56;
|
||||
static constexpr int nbatch_V2 = 56;
|
||||
static constexpr int nbatch_combine = 56;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 56;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 64;
|
||||
static constexpr int nbatch_V2 = 64;
|
||||
static constexpr int nbatch_combine = 64;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
static constexpr int nbatch_K2 = 128;
|
||||
static constexpr int nbatch_V2 = 128;
|
||||
static constexpr int nbatch_combine = 128;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
}
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
#else
|
||||
GGML_UNUSED(ncols);
|
||||
return 128;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
|
||||
static constexpr int nwarps_max = 8;
|
||||
static constexpr bool Q_in_reg = false;
|
||||
static constexpr int nstages_target = 1;
|
||||
static constexpr int nbatch_K2 = 160;
|
||||
static constexpr int nbatch_V2 = 128;
|
||||
static constexpr int nbatch_combine = 128;
|
||||
|
||||
static int get_nbatch_K2_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 96 : 160;
|
||||
}
|
||||
return ncols <= 16 ? 288 : 160;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 96 : 160;
|
||||
#else
|
||||
return ncols <= 16 ? 288 : 160;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
}
|
||||
return ncols <= 16 ? 256 : 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
#else
|
||||
return ncols <= 16 ? 256 : 128;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------------------------------------------------
|
||||
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||
|
||||
auto load = [&] __device__ (const int n) {
|
||||
auto load = [&] __device__ (auto n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
|
||||
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
||||
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
||||
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
__syncthreads();
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
|
||||
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
||||
} else {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
|
||||
const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
||||
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||
const int k0_diff = k0_stop - k0_start;
|
||||
|
||||
if (nstages <= 1) {
|
||||
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
||||
}
|
||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
|
||||
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
|
||||
const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if (nstages <= 1) {
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over V in reverse and re-use the data if possible.
|
||||
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
||||
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
||||
#pragma unroll
|
||||
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
||||
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
||||
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
||||
|
||||
// Calculate VKQ tile:
|
||||
#pragma unroll
|
||||
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
||||
|
||||
tile_A A;
|
||||
load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
if (ntiles == 1) {
|
||||
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
||||
} else {
|
||||
@@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr int cols_per_warp = ntiles * tile_B::I;
|
||||
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
||||
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
||||
|
||||
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
||||
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = c::nbatch_K2 + 4;
|
||||
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
||||
|
||||
extern __shared__ half2 tile_Q[];
|
||||
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
||||
constexpr bool use_cp_async = true;
|
||||
if (ncols2 > 1 || mask_h2) {
|
||||
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
||||
}
|
||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
|
||||
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
||||
}
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
||||
constexpr bool last_iter = false;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
}
|
||||
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
||||
constexpr bool last_iter = true;
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
||||
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
||||
}
|
||||
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
||||
|
||||
constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
|
||||
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
|
||||
constexpr int tile_stride = nbatch_combine + 4;
|
||||
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
||||
|
||||
@@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
@@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
|
||||
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
||||
|
||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||
|
||||
@@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int stride_Q1 = nb01 / sizeof(float2);
|
||||
const int stride_Q2 = nb02 / sizeof(float2);
|
||||
const int stride_K = nb11 / sizeof(half2);
|
||||
const int stride_V = nb21 / sizeof(half2);
|
||||
const int stride_mask = nb31 / sizeof(half2);
|
||||
|
||||
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
||||
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
@@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
@@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
@@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
@@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
@@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
|
||||
typedef fattn_mma_f16_config<DKQ, DV> c;
|
||||
|
||||
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
|
||||
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
|
||||
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
|
||||
|
||||
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
||||
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
@@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
||||
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
||||
|
||||
constexpr bool mla = DKQ == 576;
|
||||
|
||||
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
|
||||
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
|
||||
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
|
||||
|
||||
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
||||
static_assert(DV % tile_A::J == 0, "bad DV");
|
||||
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
||||
|
||||
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
||||
|
||||
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
||||
|
||||
@@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
@@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
template <int DKQ, int DV, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if constexpr (ncols2 <= 8) {
|
||||
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
|
||||
if (!new_mma_available(cc)) {
|
||||
return false;
|
||||
}
|
||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
||||
|
||||
@@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
@@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
|
||||
|
||||
@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
|
||||
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
||||
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
||||
|
||||
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
|
||||
const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
|
||||
|
||||
if (i0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i1 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
const int64_t i2 = blockIdx.z % ne2;
|
||||
const int64_t i3 = blockIdx.z / ne2;
|
||||
|
||||
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
|
||||
|
||||
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
||||
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
|
||||
const int64_t iqs = i0 % (4*QK8_1); // quant index in block
|
||||
|
||||
// Load 4 floats per thread and calculate max. abs. value between them:
|
||||
@@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda(
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
|
||||
|
||||
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
|
||||
// ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
|
||||
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||
switch (mmq_get_q8_1_ds_layout(type_src0)) {
|
||||
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||
|
||||
@@ -4358,7 +4358,7 @@ static bool ggml_metal_encode_node(
|
||||
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
||||
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
||||
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
||||
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
||||
if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
|
||||
@@ -3887,6 +3887,11 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
sm[tiisg] = pm[ic + tiisg];
|
||||
}
|
||||
|
||||
// skip -INF blocks
|
||||
if (simd_max(sm[tiisg]) == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
// each simdgroup processes 1 query and NE (NW/NL) head elements
|
||||
|
||||
@@ -49,34 +49,38 @@ endif()
|
||||
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
|
||||
|
||||
# Link against oneDNN
|
||||
find_package(DNNL)
|
||||
set(GGML_SYCL_DNNL 0)
|
||||
if(DNNL_FOUND)
|
||||
if (NOT DEFINED DNNL_GPU_VENDOR)
|
||||
# default to intel target
|
||||
set(DNNL_GPU_VENDOR "INTEL")
|
||||
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
||||
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
||||
if(GGML_SYCL_DNN)
|
||||
find_package(DNNL)
|
||||
if(DNNL_FOUND)
|
||||
if (NOT DEFINED DNNL_GPU_VENDOR)
|
||||
# default to intel target
|
||||
set(DNNL_GPU_VENDOR "INTEL")
|
||||
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
||||
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Verify oneDNN was compiled for the same target as llama
|
||||
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
|
||||
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
||||
set(GGML_SYCL_DNNL 1)
|
||||
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
|
||||
foreach(CONFIG ${CONFIGS})
|
||||
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
|
||||
message(STATUS "Found oneDNN: ${DNNL_LIB}")
|
||||
endforeach()
|
||||
# Verify oneDNN was compiled for the same target as llama
|
||||
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
|
||||
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
||||
set(GGML_SYCL_DNNL 1)
|
||||
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
|
||||
foreach(CONFIG ${CONFIGS})
|
||||
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
|
||||
message(STATUS "Found oneDNN: ${DNNL_LIB}")
|
||||
endforeach()
|
||||
else()
|
||||
message(WARNING
|
||||
"oneDNN must be compiled for the same target as llama.cpp.
|
||||
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
|
||||
Disabling oneDNN support.")
|
||||
endif()
|
||||
else()
|
||||
message(WARNING
|
||||
"oneDNN must be compiled for the same target as llama.cpp.
|
||||
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
|
||||
Disabling oneDNN support.")
|
||||
message(STATUS "oneDNN not found, disabling oneDNN support")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "oneDNN not found, disabling oneDNN support")
|
||||
message(STATUS "oneDNN support disabled by the user")
|
||||
endif()
|
||||
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
|
||||
|
||||
|
||||
+109
-220
@@ -1,93 +1,74 @@
|
||||
#include "binbcast.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) /
|
||||
ne3;
|
||||
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) %
|
||||
ne3;
|
||||
|
||||
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
for (int i0 = i0s; i0 < ne0;
|
||||
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
|
||||
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
|
||||
auto element_id = it.get_global_id(0);
|
||||
auto global_range = it.get_global_range(0);
|
||||
for (; element_id < num_elements; element_id += global_range) {
|
||||
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
|
||||
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
|
||||
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
|
||||
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
|
||||
dst[element_id] = val_to_store;
|
||||
}
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
|
||||
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
|
||||
int s11, int s12, int s13, std::size_t num_dst_elements,
|
||||
const sycl::nd_item<1> & item_ct1) {
|
||||
auto calculate_logical_index =
|
||||
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
|
||||
std::array<int, 4> logical_index;
|
||||
#pragma unroll(4)
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
logical_index[i] = element_id % dims[i];
|
||||
element_id /= dims[i];
|
||||
}
|
||||
return logical_index;
|
||||
};
|
||||
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
|
||||
const std::array<int, 4> & indices) __attribute__((always_inline))
|
||||
->std::size_t {
|
||||
std::size_t index = 0;
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < 4; i++) {
|
||||
auto index_i = indices[i];
|
||||
if (indices[i] >= dims[i]) {
|
||||
index_i = indices[i] % dims[i];
|
||||
}
|
||||
index += strides[i] * index_i;
|
||||
}
|
||||
return index;
|
||||
};
|
||||
|
||||
const int i3 = i/(ne2*ne1*ne0);
|
||||
const int i2 = (i/(ne1*ne0)) % ne2;
|
||||
const int i1 = (i/ne0) % ne1;
|
||||
const int i0 = i % ne0;
|
||||
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
auto element_id = item_ct1.get_global_id(0);
|
||||
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
|
||||
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
|
||||
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
|
||||
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
|
||||
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
|
||||
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
|
||||
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
|
||||
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
|
||||
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
|
||||
dst[dst_index] = val_to_store;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
}
|
||||
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
struct bin_bcast_sycl {
|
||||
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
|
||||
template <typename src0_t, typename src1_t, typename dst_t>
|
||||
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
|
||||
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
|
||||
@@ -96,165 +77,73 @@ struct bin_bcast_sycl {
|
||||
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
|
||||
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
|
||||
int nr0 = ne10 / ne0;
|
||||
int nr1 = ne11/ne1;
|
||||
int nr2 = ne12/ne2;
|
||||
int nr3 = ne13/ne3;
|
||||
|
||||
int nr[4] = { nr0, nr1, nr2, nr3 };
|
||||
|
||||
// collapse dimensions until first broadcast dimension
|
||||
int64_t cne[] = {ne0, ne1, ne2, ne3};
|
||||
int64_t cne0[] = {ne00, ne01, ne02, ne03};
|
||||
int64_t cne1[] = {ne10, ne11, ne12, ne13};
|
||||
size_t cnb[] = {nb0, nb1, nb2, nb3};
|
||||
size_t cnb0[] = {nb00, nb01, nb02, nb03};
|
||||
size_t cnb1[] = {nb10, nb11, nb12, nb13};
|
||||
auto collapse = [](int64_t cne[]) {
|
||||
cne[0] *= cne[1];
|
||||
cne[1] = cne[2];
|
||||
cne[2] = cne[3];
|
||||
cne[3] = 1;
|
||||
};
|
||||
|
||||
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
|
||||
cnb[1] *= cne[1];
|
||||
cnb[2] *= cne[2];
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
|
||||
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
|
||||
const std::array<int64_t, 4> & dst_dims) -> bool {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
}
|
||||
if (i > 0) {
|
||||
collapse_nb(cnb, cne);
|
||||
collapse_nb(cnb0, cne0);
|
||||
collapse_nb(cnb1, cne1);
|
||||
collapse(cne);
|
||||
collapse(cne0);
|
||||
collapse(cne1);
|
||||
if (dst_dims[i] > src_dims[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
int64_t ne0 = cne[0];
|
||||
int64_t ne1 = cne[1];
|
||||
int64_t ne2 = cne[2];
|
||||
int64_t ne3 = cne[3];
|
||||
return false;
|
||||
};
|
||||
|
||||
int64_t ne10 = cne1[0];
|
||||
int64_t ne11 = cne1[1];
|
||||
int64_t ne12 = cne1[2];
|
||||
int64_t ne13 = cne1[3];
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
size_t nb0 = cnb[0];
|
||||
size_t nb1 = cnb[1];
|
||||
size_t nb2 = cnb[2];
|
||||
size_t nb3 = cnb[3];
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
|
||||
size_t nb00 = cnb0[0];
|
||||
size_t nb01 = cnb0[1];
|
||||
size_t nb02 = cnb0[2];
|
||||
size_t nb03 = cnb0[3];
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
|
||||
size_t nb10 = cnb1[0];
|
||||
size_t nb11 = cnb1[1];
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
// dst strides in number of elements
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
// src1 strides in number of elements
|
||||
size_t s10 = nb10 / sizeof(src0_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
// src0 strides in number of elements
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
|
||||
GGML_UNUSED(s00);
|
||||
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
|
||||
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
|
||||
std::size_t local_range = 256;
|
||||
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
|
||||
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
|
||||
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
|
||||
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0/2LL, 1LL);
|
||||
|
||||
sycl::range<3> block_dims(1, 1, 1);
|
||||
block_dims[2] = std::min<unsigned int>(hne0, block_size);
|
||||
block_dims[1] = std::min<unsigned int>(
|
||||
ne1, block_size / (unsigned int)block_dims[2]);
|
||||
block_dims[0] = std::min(
|
||||
std::min<unsigned int>(
|
||||
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
|
||||
(unsigned int)block_dims[1]),
|
||||
64U);
|
||||
|
||||
sycl::range<3> block_nums(
|
||||
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
|
||||
(ne1 + block_dims[1] - 1) / block_dims[1],
|
||||
(hne0 + block_dims[2] - 1) / block_dims[2]);
|
||||
|
||||
if (block_nums[0] > 65535) {
|
||||
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
|
||||
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
|
||||
sycl::range<3>(1, 1, block_size),
|
||||
sycl::range<3>(1, 1, block_size)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast_unravel<bin_op>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
|
||||
s03, s11, s12, s13, item_ct1);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
DPCT1049:16: The work-group size passed to the SYCL kernel may
|
||||
exceed the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if
|
||||
needed.
|
||||
*/
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
||||
ne2, ne3, ne10, ne11, ne12, ne13,
|
||||
s1, s2, s3, s01, s02, s03, s11, s12, s13,
|
||||
item_ct1);
|
||||
});
|
||||
}
|
||||
if (! needs_broadcasting && all_contiguous) {
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
|
||||
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
|
||||
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
|
||||
const int64_t nb = k / QK_K;
|
||||
const size_t local_size = 32;
|
||||
const size_t global_size = nb * local_size;
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::queue_ptr stream) {
|
||||
@@ -504,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
||||
case GGML_TYPE_Q3_K:
|
||||
return dequantize_row_q3_K_sycl;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return dequantize_row_q4_K_sycl;
|
||||
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q4_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q4_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
@@ -556,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
||||
case GGML_TYPE_Q3_K:
|
||||
return dequantize_row_q3_K_sycl;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return dequantize_row_q4_K_sycl;
|
||||
if (dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q4_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q4_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
|
||||
@@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename dst_t>
|
||||
inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
|
||||
const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
|
||||
const int is = 2 * il;
|
||||
constexpr int n = 4;
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, scales_local, sc, m);
|
||||
const float d1 = dall * sc;
|
||||
const float m1 = dmin * m;
|
||||
|
||||
get_scale_min_k4(is + 1, scales_local, sc, m);
|
||||
const float d2 = dall * sc;
|
||||
const float m2 = dmin * m;
|
||||
|
||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
|
||||
y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
||||
@@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
||||
const int64_t i = item_ct1.get_group(2);
|
||||
|
||||
#if QK_K == 256
|
||||
// assume 32 threads
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const int64_t il = tid/8;
|
||||
const int64_t ir = tid%8;
|
||||
const int64_t is = 2*il;
|
||||
const int64_t n = 4;
|
||||
const int64_t il = tid / 8;
|
||||
const int64_t ir = tid % 8;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
|
||||
|
||||
const sycl::half2 dm = x[i].dm;
|
||||
const float dall = dm[0];
|
||||
const float dmin = dm[1];
|
||||
|
||||
if (tid < 12)
|
||||
if (tid < 12) {
|
||||
scales_local[tid] = x[i].scales[tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
uint8_t sc, m;
|
||||
get_scale_min_k4(is + 0, scales_local, sc, m);
|
||||
const float d1 = dall * sc;
|
||||
const float m1 = dmin * m;
|
||||
get_scale_min_k4(is + 1, scales_local, sc, m);
|
||||
const float d2 = dall * sc;
|
||||
const float m2 = dmin * m;
|
||||
|
||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
|
||||
for (int l = 0; l < n; ++l) {
|
||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
|
||||
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
|
||||
#else
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const uint8_t * q = x[i].qs;
|
||||
@@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
|
||||
const sycl::nd_item<1> & item_ct1, int64_t nb) {
|
||||
const int64_t i = item_ct1.get_group(0); // block index
|
||||
const int64_t tid = item_ct1.get_local_id(0); // thread index within block
|
||||
const int64_t il = tid / 8;
|
||||
const int64_t ir = tid % 8;
|
||||
|
||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vx);
|
||||
const size_t qs_offset = i * (QK_K / 2);
|
||||
const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
|
||||
const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
|
||||
|
||||
const uint8_t * qs_ptr = base + qs_offset;
|
||||
const uint8_t * scales_ptr = base + scales_offset;
|
||||
ggml_half2 dm_values = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
|
||||
|
||||
const float dall = dm_values.x();
|
||||
const float dmin = dm_values.y();
|
||||
|
||||
if (tid < 12) {
|
||||
scales_local[tid] = scales_ptr[tid];
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
|
||||
@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
|
||||
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
// reorder is currently not supported for dmmv
|
||||
GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
|
||||
} else {
|
||||
dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
|
||||
|
||||
@@ -655,7 +655,6 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -688,7 +687,6 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -722,7 +720,6 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -754,7 +751,6 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -786,7 +782,6 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -818,7 +813,6 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -850,7 +844,6 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -883,7 +876,6 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -917,7 +909,6 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -949,7 +940,6 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -981,7 +971,6 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1013,7 +1002,6 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1045,7 +1033,6 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1078,7 +1065,6 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1110,7 +1096,6 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1142,7 +1127,6 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1174,7 +1158,6 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1206,7 +1189,6 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1241,7 +1223,6 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1273,7 +1254,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1315,7 +1295,6 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1350,7 +1329,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1388,7 +1366,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,16 +32,36 @@ public:
|
||||
else static_assert(0);
|
||||
}
|
||||
|
||||
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
|
||||
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
||||
// matrix A has m rows, k columns
|
||||
// matrix B has k rows, n columns
|
||||
// nra - number of elements to skip when moving into next row in A
|
||||
// nrb - number of elements to skip when moving into next row in B
|
||||
// nca - number of elements to skip when moving into next column in A
|
||||
// ncb - number of elements to skip when moving into next column in B
|
||||
// stride_a - number of elements to skip when moving to next A matrix
|
||||
// stride_b - number of elements to skip when moving to next B matrix
|
||||
// batches_a - number of A matrices
|
||||
// batches_b - number of B matrices
|
||||
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
|
||||
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
|
||||
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
|
||||
|
||||
auto stream = ctx.stream_dnnl(q);
|
||||
auto eng = ctx.engine_dnnl(q);
|
||||
dnnl::memory::dims a_dims = { m, k };
|
||||
dnnl::memory::dims b_dims = { k, n };
|
||||
dnnl::memory::dims c_dims = { m, n };
|
||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||
|
||||
// { # strides, # rows, # columns }
|
||||
dnnl::memory::dims a_dims = { batches_a, m, k };
|
||||
dnnl::memory::dims b_dims = { batches_b, k, n };
|
||||
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
|
||||
|
||||
// { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
|
||||
dnnl::memory::dims a_strides = { stride_a, nra, nca };
|
||||
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
|
||||
|
||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
|
||||
|
||||
dnnl::primitive_attr primitive_attr;
|
||||
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
@@ -63,6 +83,15 @@ public:
|
||||
|
||||
matmul_prim.execute(stream, matmul_args);
|
||||
}
|
||||
|
||||
// matrices A and B are column major, both having k rows
|
||||
// matrix A has m column, matrix B has n columns
|
||||
// output: column major matrix C = A transposed * B
|
||||
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
||||
|
||||
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_disable_dnn = 0;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
|
||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
|
||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
#ifdef GGML_SYCL_GRAPH
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
|
||||
#endif
|
||||
#if GGML_SYCL_DNNL
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
|
||||
#else
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
|
||||
#endif
|
||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
@@ -341,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
int id;
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
#if !GGML_SYCL_DNNL
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
const int64_t ne0 = dst->ne[0]; // used by MKL only
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// ldc == nrows of the matrix that cuBLAS writes into
|
||||
int ldc = id == ctx.device ? ne0 : row_diff;
|
||||
#endif
|
||||
int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
bool use_fp16 = true; // TODO(Yu) SYCL capability check
|
||||
@@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
: src1_as_f16.get();
|
||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
|
||||
|
||||
#if !GGML_SYCL_DNNL
|
||||
const sycl::half alpha_f16 = 1.0f;
|
||||
const sycl::half beta_f16 = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||
*stream, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
||||
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
||||
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
||||
dpct::library_data_t::real_half)));
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
#else
|
||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
const sycl::half alpha_f16 = 1.0f;
|
||||
const sycl::half beta_f16 = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||
*stream, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
||||
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
||||
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
||||
dpct::library_data_t::real_half)));
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
|
||||
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
|
||||
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
|
||||
|
||||
#if !GGML_SYCL_DNNL
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
||||
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
||||
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
||||
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||
#else
|
||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
||||
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
||||
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
||||
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_ddq_i);
|
||||
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
|
||||
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
|
||||
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
|
||||
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
|
||||
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
|
||||
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
|
||||
|
||||
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
|
||||
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
|
||||
uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
|
||||
uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
|
||||
|
||||
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
|
||||
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
|
||||
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
}
|
||||
|
||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
|
||||
char * dst_t = reinterpret_cast<char *>(dst_ddf);
|
||||
|
||||
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
|
||||
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
|
||||
@@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
|
||||
GGML_ASSERT(ne10 == ne00);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
const int64_t r3 = ne13 / ne03;
|
||||
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
|
||||
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
||||
} else {
|
||||
const int ne23 = ne12 * ne13;
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
|
||||
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
|
||||
|
||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
|
||||
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
|
||||
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
||||
DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
|
||||
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
|
||||
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
|
||||
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
|
||||
};
|
||||
|
||||
sycl::range<3> block_dims(1, ne12, ne13);
|
||||
queue->submit([&](sycl::handler & cgh) {
|
||||
const void ** ptrs_src_get = ptrs_src.get();
|
||||
void ** ptrs_dst_get = ptrs_dst.get();
|
||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
||||
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
||||
if (r2 == 1 && r3 == 1) {
|
||||
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
|
||||
}
|
||||
else {
|
||||
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
|
||||
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
|
||||
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
|
||||
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// iterate over batches from smaller set of matrices (matrix 0)
|
||||
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
|
||||
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
|
||||
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
|
||||
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
|
||||
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
|
||||
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
||||
} else {
|
||||
const int ne23 = ne12 * ne13;
|
||||
|
||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
|
||||
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
|
||||
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
||||
|
||||
sycl::range<3> block_dims(1, ne12, ne13);
|
||||
queue->submit([&](sycl::handler & cgh) {
|
||||
const void ** ptrs_src_get = ptrs_src.get();
|
||||
void ** ptrs_dst_get = ptrs_dst.get();
|
||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
||||
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
||||
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
||||
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
||||
}
|
||||
}
|
||||
} catch (const sycl::exception & exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||
@@ -2841,6 +2900,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
case GGML_TYPE_Q4_K:
|
||||
return !g_ggml_sycl_prioritize_dmmv;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2858,6 +2919,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
|
||||
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -2883,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
}
|
||||
}
|
||||
|
||||
static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
|
||||
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
|
||||
dpct::queue_ptr stream) {
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
|
||||
.wait()));
|
||||
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
||||
int offset_blks = offset / sizeof(block_q4_0);
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
|
||||
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
||||
|
||||
stream->parallel_for(
|
||||
@@ -2906,18 +2968,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
|
||||
}
|
||||
*(d_ptr + ib) = x[ib].d;
|
||||
});
|
||||
}).wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(size % sizeof(block_q4_K) == 0);
|
||||
GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
|
||||
|
||||
const int nblocks = size / sizeof(block_q4_K);
|
||||
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
||||
|
||||
auto * qs_ptr = data_device;
|
||||
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
|
||||
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
|
||||
|
||||
stream->parallel_for(nblocks, [=](auto i) {
|
||||
const block_q4_K * x = (const block_q4_K *) tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
for (int j = 0; j < QK_K / 2; ++j) {
|
||||
qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < K_SCALE_SIZE; ++j) {
|
||||
scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
|
||||
}
|
||||
|
||||
dm_ptr[ib] = x[ib].dm;
|
||||
}).wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
char*data_device = (char*)src0->data;
|
||||
uint8_t * data_device = (uint8_t *) src0->data;
|
||||
size_t ncols = src0->ne[0];
|
||||
size_t nrows = src0->ne[1];
|
||||
size_t size = ggml_nbytes(src0);
|
||||
|
||||
reorder_qw(data_device, ncols, nrows, size, 0, stream);
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
reorder_qw_q4_k(data_device, size, 0, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("reorder_qw() called with unsupported type");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
|
||||
@@ -2960,8 +3063,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
|
||||
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
||||
static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
||||
src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
}
|
||||
|
||||
static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
||||
src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
|
||||
@@ -2984,13 +3097,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
}
|
||||
|
||||
// check data types and tensor shapes for custom matrix multiplication kernels:
|
||||
bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
|
||||
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
|
||||
|
||||
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
@@ -3713,7 +3822,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
|
||||
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
|
||||
|
||||
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
|
||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||
model_sycl_graph.end_recording();
|
||||
|
||||
@@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
const int blocks_per_row = ncols / block_traits::qk;
|
||||
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
|
||||
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
|
||||
const int nblocks = nrows * (ncols / block_traits::qk);
|
||||
|
||||
static_assert(blocks_per_subgroup > 0);
|
||||
static_assert(block_elements_per_subgroup > 0);
|
||||
@@ -45,7 +46,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
// x block quant index when casting the quants to int
|
||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
|
||||
nrows, nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
@@ -1035,7 +1057,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
|
||||
@@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
|
||||
template <> struct block_q_t<GGML_TYPE_Q4_K> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK_K;
|
||||
static constexpr uint32_t qi = QI4_K;
|
||||
static constexpr uint32_t qr = QR4_K;
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
auto nblocks = (nrows * (ncols / traits::qk));
|
||||
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
|
||||
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
|
||||
|
||||
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
|
||||
};
|
||||
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
||||
|
||||
@@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
||||
int v[q4_0_traits::vdr_mmvq];
|
||||
@@ -303,6 +303,67 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
};
|
||||
};
|
||||
|
||||
static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
|
||||
const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
|
||||
const int & iqs) {
|
||||
int v[2];
|
||||
int u[2 * QR4_K];
|
||||
float d8[QR4_K];
|
||||
|
||||
v[0] = q4[0];
|
||||
v[1] = q4[4];
|
||||
|
||||
uint16_t aux[2];
|
||||
const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
|
||||
if (j < 2) {
|
||||
aux[0] = scales[j + 0] & 0x3f3f;
|
||||
aux[1] = scales[j + 2] & 0x3f3f;
|
||||
} else {
|
||||
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
|
||||
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
|
||||
}
|
||||
|
||||
const uint8_t * sc = (const uint8_t *) aux;
|
||||
const uint8_t * m = sc + 2;
|
||||
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
|
||||
for (int i = 0; i < QR4_K; ++i) {
|
||||
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
||||
d8[i] = bq8i->ds[0];
|
||||
|
||||
const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4);
|
||||
u[2 * i + 0] = q8[0];
|
||||
u[2 * i + 1] = q8[4];
|
||||
}
|
||||
|
||||
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8);
|
||||
}
|
||||
|
||||
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
||||
static constexpr ggml_type gtype = GGML_TYPE_Q4_K;
|
||||
|
||||
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
|
||||
using q4_k_traits = typename q4_k_block::traits;
|
||||
|
||||
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
|
||||
const int ib = ibx_offset / (QK_K / 2);
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||
const uint8_t * qs = base + ibx_offset;
|
||||
const int total_qs_bytes = nblocks * (QK_K / 2);
|
||||
const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
|
||||
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
|
||||
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
||||
const uint16_t * scales = (const uint16_t *) scs;
|
||||
|
||||
return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
|
||||
}
|
||||
};
|
||||
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
@@ -649,52 +710,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
|
||||
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
|
||||
}
|
||||
|
||||
static __dpct_inline__ float
|
||||
vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
|
||||
const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
|
||||
|
||||
static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
|
||||
const int & iqs) {
|
||||
#ifndef GGML_QKK_64
|
||||
|
||||
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
|
||||
|
||||
int v[2];
|
||||
int u[2*QR4_K];
|
||||
float d8[QR4_K];
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
||||
const uint16_t * scales = (const uint16_t *) bq4_K->scales;
|
||||
|
||||
// iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
|
||||
const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
|
||||
|
||||
// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
|
||||
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
|
||||
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
|
||||
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
|
||||
|
||||
const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
|
||||
v[0] = q4[0];
|
||||
v[1] = q4[4];
|
||||
|
||||
const uint16_t * scales = (const uint16_t *)bq4_K->scales;
|
||||
uint16_t aux[2];
|
||||
const int j = bq8_offset/2;
|
||||
if (j < 2) {
|
||||
aux[0] = scales[j+0] & 0x3f3f;
|
||||
aux[1] = scales[j+2] & 0x3f3f;
|
||||
} else {
|
||||
aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
|
||||
aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
|
||||
}
|
||||
const uint8_t * sc = (const uint8_t *)aux;
|
||||
const uint8_t * m = sc + 2;
|
||||
|
||||
for (int i = 0; i < QR4_K; ++i) {
|
||||
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
|
||||
d8[i] = bq8i->ds[0];
|
||||
|
||||
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
|
||||
u[2*i+0] = q8[0];
|
||||
u[2*i+1] = q8[4];
|
||||
}
|
||||
|
||||
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
|
||||
return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs);
|
||||
|
||||
#else
|
||||
|
||||
|
||||
@@ -15,6 +15,32 @@ function(detect_host_compiler)
|
||||
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Function to test shader extension support
|
||||
# Parameters:
|
||||
# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product")
|
||||
# TEST_SHADER_FILE - Path to the test shader file
|
||||
# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result
|
||||
function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE)
|
||||
execute_process(
|
||||
COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error
|
||||
)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*")
|
||||
message(STATUS "${EXTENSION_NAME} not supported by glslc")
|
||||
set(${RESULT_VARIABLE} OFF PARENT_SCOPE)
|
||||
else()
|
||||
message(STATUS "${EXTENSION_NAME} supported by glslc")
|
||||
set(${RESULT_VARIABLE} ON PARENT_SCOPE)
|
||||
add_compile_definitions(${RESULT_VARIABLE})
|
||||
|
||||
# Ensure the extension support is forwarded to vulkan-shaders-gen
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON)
|
||||
set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
if (Vulkan_FOUND)
|
||||
message(STATUS "Vulkan found")
|
||||
|
||||
@@ -23,69 +49,35 @@ if (Vulkan_FOUND)
|
||||
../../include/ggml-vulkan.h
|
||||
)
|
||||
|
||||
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
set(VULKAN_SHADER_GEN_CMAKE_ARGS
|
||||
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
|
||||
-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
|
||||
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
endif()
|
||||
# Test all shader extensions
|
||||
test_shader_extension_support(
|
||||
"GL_KHR_cooperative_matrix"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
|
||||
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
test_shader_extension_support(
|
||||
"GL_NV_cooperative_matrix2"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
|
||||
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
|
||||
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
endif()
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_integer_dot_product"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
|
||||
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
|
||||
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
|
||||
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
|
||||
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
endif()
|
||||
|
||||
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
|
||||
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
|
||||
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
|
||||
else()
|
||||
message(STATUS "GL_EXT_bfloat16 supported by glslc")
|
||||
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
|
||||
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
endif()
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_bfloat16"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
|
||||
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
|
||||
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||
@@ -124,16 +116,8 @@ if (Vulkan_FOUND)
|
||||
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
|
||||
endif()
|
||||
|
||||
if (NOT CMAKE_CROSSCOMPILING)
|
||||
add_subdirectory(vulkan-shaders)
|
||||
if (MSVC)
|
||||
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
|
||||
string(TOUPPER ${CONFIG} CONFIG)
|
||||
set_target_properties(vulkan-shaders-gen PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||
endforeach()
|
||||
endif()
|
||||
else()
|
||||
# Set up toolchain for host compilation whether cross-compiling or not
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
|
||||
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
|
||||
else()
|
||||
@@ -146,25 +130,31 @@ if (Vulkan_FOUND)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
|
||||
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
|
||||
endif()
|
||||
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
|
||||
|
||||
include(ExternalProject)
|
||||
# Native build through ExternalProject_Add
|
||||
ExternalProject_Add(
|
||||
vulkan-shaders-gen
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
||||
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
|
||||
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
|
||||
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
|
||||
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
|
||||
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
|
||||
-DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build .
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||
)
|
||||
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
|
||||
else()
|
||||
# For non-cross-compiling, use empty toolchain (use host compiler)
|
||||
set(HOST_CMAKE_TOOLCHAIN_FILE "")
|
||||
endif()
|
||||
|
||||
# Always use ExternalProject_Add approach
|
||||
include(ExternalProject)
|
||||
|
||||
# Add toolchain file if cross-compiling
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
|
||||
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
|
||||
endif()
|
||||
|
||||
# Native build through ExternalProject_Add
|
||||
ExternalProject_Add(
|
||||
vulkan-shaders-gen
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
||||
CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS}
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build .
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||
)
|
||||
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
|
||||
|
||||
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
|
||||
set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
|
||||
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
|
||||
@@ -175,9 +165,8 @@ if (Vulkan_FOUND)
|
||||
file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
|
||||
set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
|
||||
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
|
||||
endif()
|
||||
# Add build and install dependencies for all builds
|
||||
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${_ggml_vk_header}
|
||||
|
||||
@@ -288,6 +288,9 @@ struct vk_device_struct {
|
||||
bool coopmat_acc_f32_support {};
|
||||
bool coopmat_acc_f16_support {};
|
||||
bool coopmat_bf16_support {};
|
||||
bool coopmat_support_16x16x16_f16acc {};
|
||||
bool coopmat_support_16x16x16_f32acc {};
|
||||
bool coopmat1_fa_support {};
|
||||
uint32_t coopmat_m;
|
||||
uint32_t coopmat_n;
|
||||
uint32_t coopmat_k;
|
||||
@@ -410,6 +413,13 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
||||
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
||||
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
||||
@@ -1588,19 +1598,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
||||
);
|
||||
}
|
||||
|
||||
enum FaCodePath {
|
||||
FA_SCALAR,
|
||||
FA_COOPMAT1,
|
||||
FA_COOPMAT2,
|
||||
};
|
||||
|
||||
// number of rows/cols for flash attention shader
|
||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
||||
|
||||
static uint32_t get_fa_num_small_rows(bool scalar) {
|
||||
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
|
||||
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
||||
// 128 threads split into four subgroups, each subgroup does 1/4
|
||||
// of the Bc dimension.
|
||||
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
|
||||
static constexpr uint32_t scalar_flash_attention_Bc = 64;
|
||||
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
|
||||
|
||||
static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
||||
if (path == FA_COOPMAT2) {
|
||||
return flash_attention_num_small_rows;
|
||||
} else {
|
||||
return scalar_flash_attention_num_small_rows;
|
||||
}
|
||||
}
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
||||
GGML_UNUSED(clamp);
|
||||
|
||||
if (scalar) {
|
||||
if (path == FA_SCALAR) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, 64};
|
||||
} else {
|
||||
@@ -1608,9 +1635,17 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
|
||||
}
|
||||
}
|
||||
|
||||
if (path == FA_COOPMAT1) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
|
||||
} else {
|
||||
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
||||
}
|
||||
}
|
||||
|
||||
// small rows, large cols
|
||||
if (small_rows) {
|
||||
return {get_fa_num_small_rows(scalar), 32};
|
||||
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
|
||||
}
|
||||
|
||||
// small cols to reduce register count
|
||||
@@ -1907,17 +1942,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
// For scalar, use 128 (arbitrary)
|
||||
uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
|
||||
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
||||
? scalar_flash_attention_workgroup_size
|
||||
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
@@ -1929,36 +1966,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
||||
};
|
||||
|
||||
#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
|
||||
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
||||
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
|
||||
|
||||
CREATE_FA(GGML_TYPE_F16, f16, true, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||
}
|
||||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
|
||||
}
|
||||
#endif
|
||||
#undef CREATE_FA2
|
||||
@@ -2041,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -3009,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
#if defined(VK_KHR_cooperative_matrix)
|
||||
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
||||
|
||||
// coopmat1 fa shader currently assumes 32 invocations per subgroup
|
||||
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
|
||||
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
|
||||
device->subgroup_max_size >= 32;
|
||||
#endif
|
||||
|
||||
if (coopmat2_support) {
|
||||
@@ -3143,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
// Only enable if shape is identical
|
||||
device->coopmat_acc_f32_support = true;
|
||||
}
|
||||
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
||||
device->coopmat_support_16x16x16_f32acc = true;
|
||||
}
|
||||
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
||||
// coopmat sizes not set yet
|
||||
@@ -3155,6 +3207,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
// Only enable if shape is identical
|
||||
device->coopmat_acc_f16_support = true;
|
||||
}
|
||||
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
||||
device->coopmat_support_16x16x16_f16acc = true;
|
||||
}
|
||||
}
|
||||
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
||||
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
||||
@@ -5688,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
|
||||
const uint32_t acctype = f32acc ? 4 : 2;
|
||||
const uint32_t f16vec4 = 8;
|
||||
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
||||
|
||||
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
|
||||
|
||||
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||
|
||||
const uint32_t kshstride = D / 4 + 2;
|
||||
const uint32_t ksh = Bc * kshstride * f16vec4;
|
||||
|
||||
const uint32_t slope = Br * sizeof(float);
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
|
||||
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
|
||||
@@ -5738,7 +5823,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
assert(q->type == GGML_TYPE_F32);
|
||||
assert(k->type == v->type);
|
||||
|
||||
bool scalar = !ctx->device->coopmat2;
|
||||
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
|
||||
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
if (path == FA_COOPMAT1) {
|
||||
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
||||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
||||
|
||||
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
|
||||
|
||||
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
@@ -5746,9 +5843,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
// For scalar FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat FA, we always use the small size (which is still pretty large for gqa).
|
||||
const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
uint32_t max_gqa;
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
case FA_COOPMAT1:
|
||||
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||
max_gqa = scalar_flash_attention_num_large_rows;
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
||||
@@ -5761,11 +5870,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
|
||||
vk_pipeline *pipelines;
|
||||
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
|
||||
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
|
||||
bool small_rows = N <= get_fa_num_small_rows(scalar);
|
||||
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||
|
||||
if (scalar) {
|
||||
if (small_rows && path == FA_COOPMAT1) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
||||
@@ -5777,7 +5891,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
case FA_COOPMAT1:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
||||
default:
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
switch (D) {
|
||||
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
||||
@@ -5789,6 +5917,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
GGML_ASSERT(!"unsupported D value");
|
||||
return;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
assert(pipelines);
|
||||
|
||||
|
||||
@@ -5,18 +5,35 @@ find_package (Threads REQUIRED)
|
||||
|
||||
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat2 glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling dot glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling bfloat16 glslc support")
|
||||
endif()
|
||||
|
||||
set(TARGET vulkan-shaders-gen)
|
||||
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
|
||||
|
||||
# Configure output directories for MSVC builds
|
||||
if(MSVC)
|
||||
# Get the main project's runtime output directory if possible
|
||||
if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
|
||||
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
|
||||
string(TOUPPER ${CONFIG} CONFIG)
|
||||
set_target_properties(${TARGET} PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
@@ -19,7 +20,7 @@ layout (constant_id = 3) const uint32_t D = 32;
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
@@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared vec4 tmpshv4[gl_WorkGroupSize.x];
|
||||
shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
||||
shared vec4 tmpshv4[WorkGroupSize];
|
||||
|
||||
shared float masksh[Bc][Br];
|
||||
shared vec4 Qf[Br][D / 4];
|
||||
|
||||
@@ -0,0 +1,506 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
|
||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
||||
|
||||
const uint32_t D_per_thread = D / D_split;
|
||||
const uint32_t row_split = 4;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
||||
layout (binding = 1) readonly buffer K {float16_t data_k[];};
|
||||
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
||||
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * D + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
|
||||
const uint32_t MatBr = 16;
|
||||
const uint32_t MatBc = 16;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||
|
||||
const uint32_t qstride = D / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
// Avoid padding for D==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
||||
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 ksh[Bc * kshstride];
|
||||
|
||||
shared float slope[Br];
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
|
||||
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t r = (idx + tid) / (D / 4);
|
||||
if (r < Br && d < D / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = ACC_TYPEV4(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
float Lf[rows_per_thread], Mf[rows_per_thread];
|
||||
|
||||
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
||||
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lf[r] = 0;
|
||||
Mf[r] = NEG_FLT_MAX_OVER_2;
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (p.max_bias > 0.0f) {
|
||||
if (tid < Br) {
|
||||
uint r = tid;
|
||||
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
||||
}
|
||||
barrier();
|
||||
} else {
|
||||
if (tid < Br) {
|
||||
uint r = tid;
|
||||
slope[r] = 1.0;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
||||
#else
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (D / 4);
|
||||
uint32_t c = (idx + tid) / (D / 4);
|
||||
if (c < Bc && d < D / 4) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
#else
|
||||
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
|
||||
ksh[c * kshstride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
for (uint32_t d = 0; d < D / 16; ++d) {
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
|
||||
}
|
||||
|
||||
uint coord = gl_SubgroupID * MatBc * sfshstride;
|
||||
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
barrier();
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) / Br;
|
||||
uint32_t r = (idx + tid) % Br;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (p.mask != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
float eMf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
|
||||
}
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// P = e^(S - M)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf[r] = exp(Moldf - Mf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
float Pf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
||||
Lf[r] += Pf[r];
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
// reduce across threads
|
||||
|
||||
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE M = Mf[r];
|
||||
tmpsh[tid] = M;
|
||||
// Compute max across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
M = max(M, tmpsh[tid ^ s]);
|
||||
barrier();
|
||||
tmpsh[tid] = M;
|
||||
barrier();
|
||||
}
|
||||
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Moldf[r] = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
||||
eMf[r] = exp(Moldf[r] - Mf[r]);
|
||||
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE L = Lf[r];
|
||||
tmpsh[tid] = L;
|
||||
// Compute sum across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
L += tmpsh[tid ^ s];
|
||||
barrier();
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
||||
Of[r][d] += tmpshv4[tid ^ s];
|
||||
barrier();
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
|
||||
barrier();
|
||||
}
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = D * p.ne1 * split_k_index;
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
float Lfrcp[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] *= float16_t(Lfrcp[r]);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (i * Br + tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -215,7 +215,7 @@ static std::mutex compile_count_mutex;
|
||||
static std::condition_variable compile_count_cond;
|
||||
|
||||
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
||||
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
||||
std::string out_fname = join_paths(output_dir, name + ".spv");
|
||||
std::string in_path = join_paths(input_dir, in_fname);
|
||||
|
||||
@@ -424,6 +424,7 @@ void process_shaders() {
|
||||
// flash attention
|
||||
for (const auto& f16acc : {false, true}) {
|
||||
std::string acctype = f16acc ? "float16_t" : "float";
|
||||
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "f32") {
|
||||
@@ -440,6 +441,16 @@ void process_shaders() {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
}
|
||||
#endif
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
|
||||
+33
-33
@@ -299,10 +299,10 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
|
||||
return false;
|
||||
}
|
||||
} catch (std::length_error &) {
|
||||
fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
GGML_LOG_ERROR("%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
return false;
|
||||
} catch (std::bad_alloc &) {
|
||||
fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
|
||||
return false;
|
||||
}
|
||||
kv.emplace_back(key, value);
|
||||
@@ -328,14 +328,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
ok = ok && gr.read(magic, 4);
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read magic\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read magic\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < magic.size(); i++) {
|
||||
if (magic[i] != GGUF_MAGIC[i]) {
|
||||
fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
|
||||
GGML_LOG_ERROR("%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -348,11 +348,11 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
if (ok && gr.read(ctx->version)) {
|
||||
if (ctx->version == 1) {
|
||||
fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
|
||||
GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
if (ctx->version > GGUF_VERSION) {
|
||||
fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
|
||||
GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
|
||||
__func__, ctx->version, GGUF_VERSION);
|
||||
ok = false;
|
||||
}
|
||||
@@ -363,7 +363,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
if (ok && gr.read(n_tensors)) {
|
||||
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
||||
if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
|
||||
fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
__func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
|
||||
ok = false;
|
||||
}
|
||||
@@ -374,7 +374,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
if (ok && gr.read(n_kv)) {
|
||||
static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
|
||||
if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
|
||||
fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
|
||||
__func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
|
||||
ok = false;
|
||||
}
|
||||
@@ -383,7 +383,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read header\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read header\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -399,15 +399,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
try {
|
||||
ok = ok && gr.read(key);
|
||||
} catch (std::length_error &) {
|
||||
fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
} catch (std::bad_alloc &) {
|
||||
fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
}
|
||||
for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
|
||||
if (key == ctx->kv[j].key) {
|
||||
fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
|
||||
GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
|
||||
ok = false;
|
||||
}
|
||||
}
|
||||
@@ -441,14 +441,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
case GGUF_TYPE_ARRAY:
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
|
||||
GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
|
||||
ok = false;
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -458,7 +458,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
|
||||
|
||||
if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
|
||||
fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
|
||||
GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -474,14 +474,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
try {
|
||||
ok = ok && gr.read(name);
|
||||
} catch (std::length_error &) {
|
||||
fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
} catch (std::bad_alloc &) {
|
||||
fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
|
||||
ok = false;
|
||||
}
|
||||
if (name.length() >= GGML_MAX_NAME) {
|
||||
fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
|
||||
GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
@@ -490,7 +490,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
// make sure there are no duplicate tensor names
|
||||
for (int64_t j = 0; ok && j < i; ++j) {
|
||||
if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
|
||||
fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
|
||||
GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
@@ -505,7 +505,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
uint32_t n_dims = -1;
|
||||
ok = ok && gr.read(n_dims);
|
||||
if (n_dims > GGML_MAX_DIMS) {
|
||||
fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
||||
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
|
||||
ok = false;
|
||||
break;
|
||||
@@ -518,7 +518,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// check that all ne are non-negative
|
||||
if (info.t.ne[j] < 0) {
|
||||
fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
|
||||
__func__, info.t.name, j, info.t.ne[j]);
|
||||
ok = false;
|
||||
break;
|
||||
@@ -530,7 +530,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
(INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
|
||||
(INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
|
||||
|
||||
fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
|
||||
GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape "
|
||||
"(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
|
||||
__func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
|
||||
ok = false;
|
||||
@@ -547,7 +547,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// check that tensor type is within defined range
|
||||
if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
|
||||
fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
|
||||
__func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
|
||||
ok = false;
|
||||
break;
|
||||
@@ -557,7 +557,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// check that row size is divisible by block size
|
||||
if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
|
||||
fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
|
||||
GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
|
||||
"not a multiple of block size (%" PRId64 ")\n",
|
||||
__func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
|
||||
ok = false;
|
||||
@@ -582,7 +582,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read tensor info\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -590,7 +590,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
// we require the data section to be aligned, so take into account any padding
|
||||
if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
|
||||
fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -604,9 +604,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
for (size_t i = 0; i < ctx->info.size(); ++i) {
|
||||
const gguf_tensor_info & ti = ctx->info[i];
|
||||
if (ti.offset != ctx->size) {
|
||||
fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
|
||||
__func__, ti.t.name, ti.offset, ctx->size);
|
||||
fprintf(stderr, "%s: failed to read tensor data\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -634,7 +634,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
|
||||
*params.ctx = ggml_init(pdata);
|
||||
if (*params.ctx == nullptr) {
|
||||
fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -656,7 +656,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
ok = ok && gr.read(data->data, ctx->size);
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__);
|
||||
ggml_free(ctx_data);
|
||||
*params.ctx = nullptr;
|
||||
gguf_free(ctx);
|
||||
@@ -689,7 +689,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: failed to create tensors\n", __func__);
|
||||
GGML_LOG_ERROR("%s: failed to create tensors\n", __func__);
|
||||
ggml_free(ctx_data);
|
||||
*params.ctx = nullptr;
|
||||
gguf_free(ctx);
|
||||
@@ -706,7 +706,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
FILE * file = ggml_fopen(fname, "rb");
|
||||
|
||||
if (!file) {
|
||||
fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
|
||||
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -1305,7 +1305,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
|
||||
FILE * file = ggml_fopen(fname, "wb");
|
||||
|
||||
if (!file) {
|
||||
fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
|
||||
GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -823,6 +823,7 @@ class GGUFEditorWindow(QMainWindow):
|
||||
self.modified = False
|
||||
self.metadata_changes = {} # Store changes to apply when saving
|
||||
self.metadata_to_remove = set() # Store keys to remove when saving
|
||||
self.on_metadata_changed_is_connected = False
|
||||
|
||||
self.setup_ui()
|
||||
|
||||
@@ -941,9 +942,11 @@ class GGUFEditorWindow(QMainWindow):
|
||||
return
|
||||
|
||||
# Disconnect to prevent triggering during loading
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
|
||||
if self.on_metadata_changed_is_connected:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
|
||||
self.on_metadata_changed_is_connected = False
|
||||
|
||||
for i, (key, field) in enumerate(self.reader.fields.items()):
|
||||
self.metadata_table.insertRow(i)
|
||||
@@ -1021,6 +1024,7 @@ class GGUFEditorWindow(QMainWindow):
|
||||
|
||||
# Reconnect after loading
|
||||
self.metadata_table.itemChanged.connect(self.on_metadata_changed)
|
||||
self.on_metadata_changed_is_connected = True
|
||||
|
||||
def extract_array_values(self, field: ReaderField) -> list:
|
||||
"""Extract all values from an array field."""
|
||||
|
||||
@@ -68,7 +68,7 @@ class TensorNameMap:
|
||||
"output_layer", # chatglm
|
||||
"head", # rwkv
|
||||
"head.out", # wavtokenizer
|
||||
"language_model.lm_head", # llama4
|
||||
"lm_head", # llama4
|
||||
),
|
||||
|
||||
# Output norm
|
||||
@@ -91,7 +91,7 @@ class TensorNameMap:
|
||||
"rwkv.ln_out", # rwkv6
|
||||
"model.ln_out", # rwkv7
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
"language_model.model.norm", # llama4
|
||||
"model.norm", # llama4
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
@@ -133,7 +133,7 @@ class TensorNameMap:
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
"language_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"model.layers.{bid}.input_layernorm", # llama4
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@@ -173,7 +173,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wq", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
||||
"transformer.h.{bid}.attn.attention.q_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@@ -188,7 +188,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wk", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
||||
"transformer.h.{bid}.attn.attention.k_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@@ -202,7 +202,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.attention.wv", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
|
||||
"transformer.h.{bid}.attn.attention.v_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention output
|
||||
@@ -229,7 +229,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@@ -268,7 +268,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@@ -289,7 +289,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
"language_model.model.layers.{bid}.feed_forward.router", # llama4
|
||||
"model.layers.{bid}.feed_forward.router", # llama4
|
||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||
),
|
||||
|
||||
@@ -329,7 +329,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@@ -338,14 +338,14 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
@@ -366,22 +366,22 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
@@ -410,7 +410,7 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@@ -420,15 +420,15 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
|
||||
@@ -113,7 +113,7 @@ parser.add_argument("-o", "--output", help=help_o, default="pipe")
|
||||
help_s = (
|
||||
"Columns to add to the table. "
|
||||
"Accepts a comma-separated list of values. "
|
||||
f"Legal values: {', '.join(KEY_PROPERTIES[:-2])}. "
|
||||
f"Legal values: {', '.join(KEY_PROPERTIES[:-3])}. "
|
||||
"Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) "
|
||||
"plus any column where not all data points are the same. "
|
||||
"If the columns are manually specified, then the results for each unique combination of the "
|
||||
@@ -505,7 +505,7 @@ if known_args.show is not None:
|
||||
show = known_args.show.split(",")
|
||||
unknown_cols = []
|
||||
for prop in show:
|
||||
if prop not in KEY_PROPERTIES[:-2]: # Last two values are n_prompt, n_gen.
|
||||
if prop not in KEY_PROPERTIES[:-3]: # Last three values are n_prompt, n_gen, n_depth.
|
||||
unknown_cols.append(prop)
|
||||
if unknown_cols:
|
||||
logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}")
|
||||
|
||||
@@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_write(io);
|
||||
if (kv_self != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
kv_self->state_write(io);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
@@ -441,6 +441,13 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||
|
||||
void llama_kv_cache_unified::set_full() {
|
||||
n = size;
|
||||
|
||||
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
||||
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
||||
// we should only guarantee that the head position won't cause out-of-bounds view of the K, V tensors, so
|
||||
// setting it to 0 is the simplest way to achieve that
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/13359
|
||||
head = 0;
|
||||
}
|
||||
|
||||
llama_sbatch llama_kv_cache_unified::sbatch_init(
|
||||
@@ -1712,6 +1719,7 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
||||
|
||||
void llama_kv_cache_recurrent::set_full() {
|
||||
n = size;
|
||||
head = 0;
|
||||
}
|
||||
|
||||
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
|
||||
|
||||
+4
-10
@@ -171,11 +171,8 @@ public:
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
@@ -343,11 +340,8 @@ public:
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
|
||||
@@ -469,7 +469,7 @@ llama_model_loader::llama_model_loader(
|
||||
|
||||
meta.reset(gguf_init_from_file(fname.c_str(), params));
|
||||
if (!meta) {
|
||||
throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
|
||||
throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
|
||||
}
|
||||
|
||||
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
|
||||
@@ -528,7 +528,7 @@ llama_model_loader::llama_model_loader(
|
||||
};
|
||||
gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
|
||||
if (!ctx_gguf) {
|
||||
throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
|
||||
throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
|
||||
}
|
||||
|
||||
// check idx
|
||||
@@ -822,13 +822,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
|
||||
mappings.reserve(files.size());
|
||||
mmaps_used.reserve(files.size());
|
||||
for (const auto & file : files) {
|
||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
||||
if (!reg) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
bool is_numa = false;
|
||||
|
||||
auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (dev) {
|
||||
auto * reg = ggml_backend_dev_backend_reg(dev);
|
||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||
if (is_numa_fn) {
|
||||
is_numa = is_numa_fn();
|
||||
}
|
||||
}
|
||||
|
||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
|
||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa);
|
||||
mmaps_used.emplace_back(mapping->size(), 0);
|
||||
if (mlock_mmaps) {
|
||||
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
||||
|
||||
+3
-4
@@ -12218,6 +12218,9 @@ struct llm_build_granite : public llm_graph_context {
|
||||
|
||||
// inp_pos - built only if rope enabled
|
||||
ggml_tensor * inp_pos = nullptr;
|
||||
if (use_rope) {
|
||||
inp_pos = build_inp_pos();
|
||||
}
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
@@ -12260,10 +12263,6 @@ struct llm_build_granite : public llm_graph_context {
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
if (use_rope) {
|
||||
|
||||
if (!inp_pos) {
|
||||
inp_pos = build_inp_pos();
|
||||
}
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
|
||||
+13
-11
@@ -14,6 +14,12 @@
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
// Quantization types. Changes to this struct must be replicated in quantize.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static void zeros(std::ofstream & file, size_t n) {
|
||||
char zero = 0;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
@@ -48,12 +54,6 @@ struct quantize_state_impl {
|
||||
{}
|
||||
};
|
||||
|
||||
// changes to this struct must be replicated in quantize.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static void llama_tensor_dequantize_impl(
|
||||
ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
||||
const size_t nelements, const int nthread
|
||||
@@ -796,17 +796,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// unless the user specifies a type
|
||||
if (params->tensor_types) {
|
||||
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
||||
const std::string tensor_name(tensor->name);
|
||||
for (const auto & [tname, qtype] : tensor_types) {
|
||||
if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
|
||||
if (qtype != new_type) {
|
||||
LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
|
||||
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
|
||||
if (qtype != new_type) {
|
||||
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
|
||||
new_type = qtype;
|
||||
break; // if two or more types are specified for the tensor, first match wins
|
||||
}
|
||||
new_type = qtype;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
new_type = params->token_embedding_type;
|
||||
}
|
||||
|
||||
@@ -140,6 +140,11 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
struct llama_model_params params) {
|
||||
ggml_time_init();
|
||||
|
||||
if (!params.vocab_only && ggml_backend_reg_count() == 0) {
|
||||
LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned cur_percentage = 0;
|
||||
if (params.progress_callback == NULL) {
|
||||
params.progress_callback_user_data = &cur_percentage;
|
||||
|
||||
@@ -144,6 +144,7 @@ endif()
|
||||
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-regex-partial.cpp)
|
||||
|
||||
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
|
||||
if (NOT WIN32)
|
||||
|
||||
+3
-1
@@ -832,7 +832,9 @@ static void test_template_output_parsers() {
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
||||
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
||||
|
||||
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
// Tests common_regex (esp. its partial final matches support).
|
||||
|
||||
#include "common.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
|
||||
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << " Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
|
||||
struct test_case {
|
||||
std::string pattern;
|
||||
struct input_output {
|
||||
std::string input;
|
||||
common_regex_match output;
|
||||
};
|
||||
std::vector<input_output> inputs_outputs;
|
||||
};
|
||||
|
||||
static std::string common_regex_match_type_name(common_regex_match_type type) {
|
||||
switch (type) {
|
||||
case COMMON_REGEX_MATCH_TYPE_NONE:
|
||||
return "COMMON_REGEX_MATCH_TYPE_NONE";
|
||||
case COMMON_REGEX_MATCH_TYPE_PARTIAL:
|
||||
return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
|
||||
case COMMON_REGEX_MATCH_TYPE_FULL:
|
||||
return "COMMON_REGEX_MATCH_TYPE_FULL";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static void test_regex() {
|
||||
printf("[%s]\n", __func__);
|
||||
auto test = [](const test_case & test_case) {
|
||||
common_regex cr(test_case.pattern);
|
||||
std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
|
||||
// std::cout << " partial rev: " << cr.reversed_partial_pattern.str() << '\n';
|
||||
for (const auto & input_output : test_case.inputs_outputs) {
|
||||
std::cout << " Input: " << input_output.input << '\n';
|
||||
auto m = cr.search(input_output.input, 0);
|
||||
if (m != input_output.output) {
|
||||
auto match_to_str = [&](const std::optional<common_regex_match> & m) {
|
||||
std::ostringstream ss;
|
||||
if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
ss << "<no match>";
|
||||
} else {
|
||||
GGML_ASSERT(!input_output.output.groups.empty());
|
||||
std::vector<std::string> parts;
|
||||
for (const auto & g : m->groups) {
|
||||
parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
|
||||
}
|
||||
ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
|
||||
}
|
||||
return ss.str();
|
||||
};
|
||||
std::cout << " Expected: " << match_to_str(input_output.output) << '\n';
|
||||
std::cout << " Got: " << match_to_str(m) << '\n';
|
||||
std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
|
||||
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
};
|
||||
test({
|
||||
"a",
|
||||
{
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||
{"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
|
||||
{"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"abcd",
|
||||
{
|
||||
{"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"d", {}},
|
||||
{"bcd", {}},
|
||||
{"cde", {}},
|
||||
{"cd", {}},
|
||||
{"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
|
||||
{"abbie", {}},
|
||||
{"", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
".*?ab",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"a.*?b",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
{"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"d", {}},
|
||||
{"b", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"ab(?:cd){2,4}ef",
|
||||
{
|
||||
// {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"abcde", {}},
|
||||
{"abcdef", {}},
|
||||
{"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
|
||||
{"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||
{"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
|
||||
{"abcdcdcdcdcdef", {}},
|
||||
{"abcde", {}},
|
||||
{"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"a(?:rte| pure )fact",
|
||||
{
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"fact", {}},
|
||||
{"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
|
||||
{"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
|
||||
{"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
|
||||
{"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
|
||||
{"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
|
||||
{"" , {}},
|
||||
{"pure", {}},
|
||||
{"pure fact", {}},
|
||||
}
|
||||
});
|
||||
test({
|
||||
"abc",
|
||||
{
|
||||
{" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
{" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
|
||||
{"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
|
||||
{"b", {}},
|
||||
{"c", {}},
|
||||
{"", {}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"(?:abc)?\\s*def",
|
||||
{
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
|
||||
{"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||
{"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
|
||||
{"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
|
||||
{"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
|
||||
{" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
|
||||
{"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"a+b",
|
||||
{
|
||||
{"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
|
||||
{"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
|
||||
{"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
|
||||
}
|
||||
});
|
||||
|
||||
test({
|
||||
"(?:"
|
||||
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
||||
"(" // match 2 (open_tag)
|
||||
"<tool_call>"
|
||||
"|<function_call>"
|
||||
"|<tool>"
|
||||
"|<tools>"
|
||||
"|<response>"
|
||||
"|<json>"
|
||||
"|<xml>"
|
||||
"|<JSON>"
|
||||
")?"
|
||||
"(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
|
||||
")"
|
||||
"|<function=([^>]+)>" // match 4 (function name)
|
||||
"|<function name=\"([^\"]+)\">", // match 5 (function name again)
|
||||
{
|
||||
{"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
|
||||
{"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
|
||||
{"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
|
||||
{"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
|
||||
{"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
|
||||
{"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
|
||||
{"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
|
||||
{"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
|
||||
{"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
|
||||
{"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
|
||||
{"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static void test_regex_to_reversed_partial_regex() {
|
||||
printf("[%s]\n", __func__);
|
||||
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:c)?b)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("abc"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"(a+)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a+"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"(a*)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a*"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"(a?)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a?"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"([a-z])[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("[a-z]"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"((?:\\w+)?[a-z])[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("[a-z]\\w+"));
|
||||
|
||||
assert_equals<std::string>(
|
||||
"((?:a|b))[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("(?:a|b)"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("abcd"));
|
||||
assert_equals<std::string>(
|
||||
"((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
|
||||
regex_to_reversed_partial_regex("a*b"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:b)?a)?.*)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex(".*?ab"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:b)?.*)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a.*?b"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a(bc)d"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("a(bc|de)"));
|
||||
assert_equals<std::string>(
|
||||
"((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
|
||||
regex_to_reversed_partial_regex("ab{2,4}c"));
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_regex_to_reversed_partial_regex();
|
||||
test_regex();
|
||||
std::cout << "All tests passed.\n";
|
||||
}
|
||||
@@ -687,7 +687,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
auto value = argv[i];
|
||||
auto * value = argv[i];
|
||||
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||
if (buft_list.empty()) {
|
||||
// enumerate all the devices and add their buffer types to the list
|
||||
@@ -719,7 +719,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
// memory leak present in the implementation
|
||||
// over in arg.cpp. Acceptable because we
|
||||
// only parse these args once in this program.
|
||||
auto override_group = value;
|
||||
auto * override_group = value;
|
||||
if (value[override_group_span_len] == '\0') {
|
||||
value = &value[override_group_span_len];
|
||||
last_group = true;
|
||||
@@ -730,7 +730,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
std::vector<llama_model_tensor_buft_override> group_tensor_buft_overrides{};
|
||||
auto override_span_len = std::strcspn(override_group, ";");
|
||||
while (override_span_len > 0) {
|
||||
auto override = override_group;
|
||||
auto * override = override_group;
|
||||
if (override_group[override_span_len] != '\0') {
|
||||
override_group[override_span_len] = '\0';
|
||||
override_group = &override_group[override_span_len + 1];
|
||||
@@ -743,9 +743,10 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
override[tensor_name_span_len] = '\0';
|
||||
auto tensor_name = override;
|
||||
auto buffer_type = &override[tensor_name_span_len + 1];
|
||||
auto * tensor_name = override;
|
||||
auto * buffer_type = &override[tensor_name_span_len + 1];
|
||||
if (buft_list.find(buffer_type) == buft_list.end()) {
|
||||
printf("error: unrecognized buffer type '%s'\n", buffer_type);
|
||||
printf("Available buffer types:\n");
|
||||
for (const auto & it : buft_list) {
|
||||
printf(" %s\n", ggml_backend_buft_name(it.second));
|
||||
@@ -1736,7 +1737,7 @@ struct sql_printer : public printer {
|
||||
}
|
||||
};
|
||||
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
@@ -1753,14 +1754,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
|
||||
for (int i = 1; i < n_tokens; i++) {
|
||||
tokens[i] = std::rand() % n_vocab;
|
||||
}
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
|
||||
int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
|
||||
return false;
|
||||
}
|
||||
n_processed += n_tokens;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
@@ -1770,10 +1776,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
|
||||
|
||||
for (int i = 0; i < n_gen; i++) {
|
||||
llama_decode(ctx, llama_batch_get_one(&token, 1));
|
||||
int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
|
||||
return false;
|
||||
}
|
||||
llama_synchronize(ctx);
|
||||
token = std::rand() % n_vocab;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
|
||||
@@ -1816,10 +1827,11 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "warning: sanitizer enabled, performance may be affected\n");
|
||||
#endif
|
||||
|
||||
cmd_params params = parse_cmd_params(argc, argv);
|
||||
|
||||
// initialize backends
|
||||
ggml_backend_load_all();
|
||||
|
||||
cmd_params params = parse_cmd_params(argc, argv);
|
||||
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
fprintf(stderr, "%s: error: CPU backend is not loaded\n", __func__);
|
||||
@@ -1917,13 +1929,21 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
|
||||
}
|
||||
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
|
||||
}
|
||||
test_gen(ctx, 1, t.n_threads);
|
||||
bool res = test_gen(ctx, 1, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < params.reps; i++) {
|
||||
@@ -1934,7 +1954,11 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
|
||||
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t t_start = get_time_ns();
|
||||
@@ -1944,14 +1968,22 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
|
||||
i + 1, params.reps);
|
||||
}
|
||||
test_gen(ctx, t.n_gen, t.n_threads);
|
||||
bool res = test_gen(ctx, t.n_gen, t.n_threads);
|
||||
if (!res) {
|
||||
fprintf(stderr, "%s: error: failed to run gen\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t t_ns = get_time_ns() - t_start;
|
||||
|
||||
@@ -2309,14 +2309,6 @@ struct clip_model_loader {
|
||||
}
|
||||
};
|
||||
|
||||
// read and create ggml_context containing the tensors and their data
|
||||
struct clip_ctx * clip_model_load(const char * fname, const int verbosity) {
|
||||
return clip_init(fname, clip_context_params{
|
||||
/* use_gpu */ true,
|
||||
/* verbosity */ static_cast<ggml_log_level>(verbosity),
|
||||
});
|
||||
}
|
||||
|
||||
struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) {
|
||||
g_logger_state.verbosity_thold = ctx_params.verbosity;
|
||||
clip_ctx * ctx_clip = nullptr;
|
||||
@@ -3085,19 +3077,6 @@ size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_grid_pinpoints.size();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
clip_image_f32 img;
|
||||
img.nx = ctx->vision_model.hparams.image_size;
|
||||
img.ny = ctx->vision_model.hparams.image_size;
|
||||
return clip_n_output_tokens(ctx, &img);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
|
||||
return clip_n_output_tokens(ctx, img);
|
||||
}
|
||||
|
||||
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
|
||||
const auto & params = ctx->vision_model.hparams;
|
||||
const int n_total = clip_n_output_tokens(ctx, img);
|
||||
|
||||
+45
-81
@@ -1,28 +1,9 @@
|
||||
#ifndef CLIP_H
|
||||
#define CLIP_H
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef LLAMA_SHARED
|
||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||
# ifdef LLAMA_BUILD
|
||||
# define CLIP_API __declspec(dllexport)
|
||||
# else
|
||||
# define CLIP_API __declspec(dllimport)
|
||||
# endif
|
||||
# else
|
||||
# define CLIP_API __attribute__ ((visibility ("default")))
|
||||
# endif
|
||||
#else
|
||||
# define CLIP_API
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct clip_ctx;
|
||||
|
||||
struct clip_image_size {
|
||||
@@ -39,97 +20,80 @@ struct clip_context_params {
|
||||
enum ggml_log_level verbosity;
|
||||
};
|
||||
|
||||
// deprecated, use clip_init
|
||||
CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity);
|
||||
struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params);
|
||||
|
||||
CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params);
|
||||
void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
|
||||
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
|
||||
|
||||
CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
|
||||
int32_t clip_get_image_size (const struct clip_ctx * ctx);
|
||||
int32_t clip_get_patch_size (const struct clip_ctx * ctx);
|
||||
int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
|
||||
|
||||
// TODO: should be enum, not string
|
||||
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||
const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
|
||||
const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||
size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
|
||||
|
||||
GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
|
||||
"use clip_n_output_tokens instead");
|
||||
GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
|
||||
"use clip_n_output_tokens instead");
|
||||
|
||||
CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
|
||||
// for M-RoPE, this will be the number of token positions in X and Y directions
|
||||
// for other models, X will be the total number of tokens and Y will be 1
|
||||
CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
|
||||
// this should be equal to the embedding dimension of the text model
|
||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
|
||||
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
|
||||
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
|
||||
int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
|
||||
void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
|
||||
struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
|
||||
|
||||
CLIP_API struct clip_image_size * clip_image_size_init(void);
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init (void);
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_init(void);
|
||||
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
|
||||
struct clip_image_size * clip_image_size_init(void);
|
||||
struct clip_image_u8 * clip_image_u8_init (void);
|
||||
struct clip_image_f32 * clip_image_f32_init(void);
|
||||
struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
|
||||
|
||||
// nx, ny are the output image dimensions
|
||||
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
|
||||
unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
|
||||
|
||||
CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
|
||||
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
|
||||
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
|
||||
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
|
||||
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
|
||||
void clip_image_size_free (struct clip_image_size * img_size);
|
||||
void clip_image_u8_free (struct clip_image_u8 * img);
|
||||
void clip_image_f32_free(struct clip_image_f32 * img);
|
||||
void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
|
||||
void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
|
||||
|
||||
// use for accessing underlay data of clip_image_f32_batch
|
||||
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
|
||||
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
|
||||
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
|
||||
size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
|
||||
size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
|
||||
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
|
||||
struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
|
||||
|
||||
/**
|
||||
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
|
||||
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
|
||||
*/
|
||||
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
|
||||
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
|
||||
|
||||
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||
|
||||
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
|
||||
CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
|
||||
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
|
||||
|
||||
/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
|
||||
CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
|
||||
bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
|
||||
|
||||
CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
|
||||
struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
|
||||
CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
|
||||
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
|
||||
bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
|
||||
|
||||
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
|
||||
int clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||
bool clip_is_glm(const struct clip_ctx * ctx);
|
||||
bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||
bool clip_is_llava(const struct clip_ctx * ctx);
|
||||
bool clip_is_gemma3(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CLIP_H
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||
|
||||
@@ -1,636 +0,0 @@
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
#include "ggml-cuda.h"
|
||||
#endif
|
||||
#ifdef NDEBUG
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#endif
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
// THIS FILE IS ONLY USED FOR TESTING THE QWEN2VL MODEL
|
||||
// IT IS NOT A PRODUCTION CODE
|
||||
|
||||
static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
|
||||
int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
|
||||
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
|
||||
const int patch_size = 14 * 2;
|
||||
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
|
||||
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
|
||||
auto img_tokens = image_embed->n_image_pos;
|
||||
// llama_pos mrope_pos[img_tokens * 4];
|
||||
std::vector<llama_pos> mrope_pos;
|
||||
mrope_pos.resize(img_tokens * 4);
|
||||
|
||||
for (int y = 0; y < ph; y++)
|
||||
{
|
||||
for (int x = 0; x < pw; x++)
|
||||
{
|
||||
int i = y * pw + x;
|
||||
mrope_pos[i] = *st_pos_id;
|
||||
mrope_pos[i + img_tokens] = *st_pos_id + y;
|
||||
mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
|
||||
mrope_pos[i + img_tokens * 3] = 0;
|
||||
}
|
||||
}
|
||||
*st_pos_id += std::max(pw, ph);
|
||||
|
||||
int processed = 0;
|
||||
std::vector<llama_pos> batch_mrope_pos;
|
||||
batch_mrope_pos.resize(img_tokens * 4);
|
||||
|
||||
for (int i = 0; i < img_tokens; i += n_batch) {
|
||||
int n_eval = img_tokens - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
|
||||
// llama_pos batch_mrope_pos[n_eval * 4];
|
||||
std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
|
||||
memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
||||
|
||||
llama_batch batch = {
|
||||
int32_t(n_eval), // n_tokens
|
||||
nullptr, // token
|
||||
(image_embed->embed+i*n_embd), // embed
|
||||
batch_mrope_pos.data(), // pos
|
||||
nullptr, // n_seq_id
|
||||
nullptr, // seq_id
|
||||
nullptr, // logits
|
||||
};
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
processed += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
|
||||
int N = (int) tokens.size();
|
||||
for (int i = 0; i < N; i += n_batch) {
|
||||
int n_eval = (int) tokens.size() - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
auto batch = llama_batch_get_one(&tokens[i], n_eval);
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
*st_pos_id += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) {
|
||||
std::vector<llama_token> tokens;
|
||||
tokens.push_back(id);
|
||||
return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id);
|
||||
}
|
||||
|
||||
static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){
|
||||
std::string str2 = str;
|
||||
std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
|
||||
eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
static const char * sample(struct common_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past, int * st_pos_id) {
|
||||
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx_llama);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
static std::string ret;
|
||||
if (llama_vocab_is_eog(vocab, id)) {
|
||||
ret = "</s>";
|
||||
} else {
|
||||
ret = common_token_to_piece(ctx_llama, id);
|
||||
}
|
||||
eval_id(ctx_llama, id, n_past, st_pos_id);
|
||||
return ret.c_str();
|
||||
}
|
||||
|
||||
static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
|
||||
static const char* IMG_BASE64_TAG_END = "\">";
|
||||
|
||||
static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
|
||||
begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
|
||||
end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
|
||||
}
|
||||
|
||||
static bool prompt_contains_image(const std::string& prompt) {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
return (begin != std::string::npos);
|
||||
}
|
||||
|
||||
// replaces the base64 image tag in the prompt with `replacement`
|
||||
static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
|
||||
size_t img_base64_str_start, img_base64_str_end;
|
||||
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
|
||||
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
|
||||
LOG_ERR("%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
|
||||
auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
|
||||
auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
|
||||
|
||||
auto required_bytes = base64::required_encode_size(base64_str.size());
|
||||
auto img_bytes = std::vector<unsigned char>(required_bytes);
|
||||
base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
||||
|
||||
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: could not load image from base64 string.\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
if (begin == std::string::npos || end == std::string::npos) {
|
||||
return prompt;
|
||||
}
|
||||
auto pre = prompt.substr(0, begin);
|
||||
auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
|
||||
return pre + replacement + post;
|
||||
}
|
||||
|
||||
struct llava_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
struct llama_context * ctx_llama = NULL;
|
||||
struct llama_model * model = NULL;
|
||||
};
|
||||
|
||||
static void print_usage(int, char ** argv) {
|
||||
LOG("\n example usage:\n");
|
||||
LOG("\n %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
|
||||
LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||
}
|
||||
|
||||
static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
|
||||
|
||||
// load and preprocess the image
|
||||
llava_image_embed * embed = NULL;
|
||||
auto prompt = params->prompt;
|
||||
if (prompt_contains_image(prompt)) {
|
||||
if (!params->image.empty()) {
|
||||
LOG_INF("using base64 encoded image instead of command line image path\n");
|
||||
}
|
||||
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: can't load image from prompt\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
params->prompt = remove_image_from_prompt(prompt);
|
||||
} else {
|
||||
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
|
||||
if (!embed) {
|
||||
fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
|
||||
int n_past = 0;
|
||||
int cur_pos_id = 0;
|
||||
|
||||
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
|
||||
|
||||
std::string system_prompt, user_prompt;
|
||||
size_t image_pos = prompt.find("<|vision_start|>");
|
||||
if (image_pos != std::string::npos) {
|
||||
// new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
|
||||
system_prompt = prompt.substr(0, image_pos);
|
||||
user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length());
|
||||
LOG_INF("system_prompt: %s\n", system_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
LOG_INF("user_prompt: %s\n", user_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// llava-1.5 native mode
|
||||
system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>";
|
||||
user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true);
|
||||
if (image_embed != nullptr) {
|
||||
auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip);
|
||||
qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size);
|
||||
}
|
||||
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false);
|
||||
|
||||
// generate the response
|
||||
|
||||
LOG("\n");
|
||||
|
||||
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
|
||||
if (!smpl) {
|
||||
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
LOG("%s", tmp);
|
||||
if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
|
||||
if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
|
||||
if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
common_sampler_free(smpl);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
static struct llama_model * llava_init(common_params * params) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params->numa);
|
||||
|
||||
llama_model_params model_params = common_model_params_to_llama(*params);
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
|
||||
const char * clip_path = params->mmproj.path.c_str();
|
||||
|
||||
auto prompt = params->prompt;
|
||||
if (prompt.empty()) {
|
||||
prompt = "describe the image in detail.";
|
||||
}
|
||||
|
||||
auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
|
||||
|
||||
llama_context_params ctx_params = common_context_params_to_llama(*params);
|
||||
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
|
||||
|
||||
llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
|
||||
|
||||
if (ctx_llama == NULL) {
|
||||
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
|
||||
|
||||
ctx_llava->ctx_llama = ctx_llama;
|
||||
ctx_llava->ctx_clip = ctx_clip;
|
||||
ctx_llava->model = model;
|
||||
return ctx_llava;
|
||||
}
|
||||
|
||||
static void llava_free(struct llava_context * ctx_llava) {
|
||||
if (ctx_llava->ctx_clip) {
|
||||
clip_free(ctx_llava->ctx_clip);
|
||||
ctx_llava->ctx_clip = NULL;
|
||||
}
|
||||
|
||||
llama_free(ctx_llava->ctx_llama);
|
||||
llama_model_free(ctx_llava->model);
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
|
||||
static void debug_test_mrope_2d() {
|
||||
// 1. Initialize backend
|
||||
ggml_backend_t backend = NULL;
|
||||
std::string backend_name = "";
|
||||
// #ifdef GGML_USE_CUDA
|
||||
// fprintf(stderr, "%s: using CUDA backend\n", __func__);
|
||||
// backend = ggml_backend_cuda_init(0); // init device 0
|
||||
// backend_name = "cuda";
|
||||
// if (!backend) {
|
||||
// fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||
// }
|
||||
// #endif
|
||||
// if there aren't GPU Backends fallback to CPU backend
|
||||
if (!backend) {
|
||||
backend = ggml_backend_cpu_init();
|
||||
backend_name = "cpu";
|
||||
}
|
||||
|
||||
// Calculate the size needed to allocate
|
||||
size_t ctx_size = 0;
|
||||
ctx_size += 2 * ggml_tensor_overhead(); // tensors
|
||||
// no need to allocate anything else!
|
||||
|
||||
// 2. Allocate `ggml_context` to store tensor data
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
|
||||
};
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
|
||||
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4);
|
||||
ggml_set_name(pos, "pos");
|
||||
ggml_set_input(pos);
|
||||
|
||||
std::vector<float> dummy_q;
|
||||
dummy_q.resize(128 * 12 * 30);
|
||||
std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
|
||||
// memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
|
||||
|
||||
std::vector<int> pos_id;
|
||||
pos_id.resize(30 * 4);
|
||||
for (int i = 0; i < 30; i ++) {
|
||||
pos_id[i] = i;
|
||||
pos_id[i + 30] = i + 10;
|
||||
pos_id[i + 60] = i + 20;
|
||||
pos_id[i + 90] = i + 30;
|
||||
}
|
||||
int sections[4] = {32, 32, 0, 0};
|
||||
|
||||
// 4. Allocate a `ggml_backend_buffer` to store all tensors
|
||||
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
||||
|
||||
// 5. Copy tensor data from main memory (RAM) to backend buffer
|
||||
ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw));
|
||||
ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos));
|
||||
|
||||
// 6. Create a `ggml_cgraph` for mul_mat operation
|
||||
struct ggml_cgraph * gf = NULL;
|
||||
struct ggml_context * ctx_cgraph = NULL;
|
||||
|
||||
// create a temporally context to build the graph
|
||||
struct ggml_init_params params0 = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
|
||||
};
|
||||
ctx_cgraph = ggml_init(params0);
|
||||
gf = ggml_new_graph(ctx_cgraph);
|
||||
|
||||
struct ggml_tensor * result0 = ggml_rope_multi(
|
||||
ctx_cgraph, inp_raw, pos, nullptr,
|
||||
128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1,
|
||||
0, 1, 32, 1);
|
||||
|
||||
// Add "result" tensor and all of its dependencies to the cgraph
|
||||
ggml_build_forward_expand(gf, result0);
|
||||
|
||||
// 7. Create a `ggml_gallocr` for cgraph computation
|
||||
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
|
||||
ggml_gallocr_alloc_graph(allocr, gf);
|
||||
|
||||
// 9. Run the computation
|
||||
int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
|
||||
if (ggml_backend_is_cpu(backend)) {
|
||||
ggml_backend_cpu_set_n_threads(backend, n_threads);
|
||||
}
|
||||
ggml_backend_graph_compute(backend, gf);
|
||||
|
||||
// 10. Retrieve results (output tensors)
|
||||
// in this example, output tensor is always the last tensor in the graph
|
||||
struct ggml_tensor * result = result0;
|
||||
// struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
|
||||
float * result_data = (float *)malloc(ggml_nbytes(result));
|
||||
// because the tensor data is stored in device buffer, we need to copy it back to RAM
|
||||
ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
|
||||
const std::string bin_file = "mrope_2d_" + backend_name +".bin";
|
||||
std::ofstream outFile(bin_file, std::ios::binary);
|
||||
|
||||
if (outFile.is_open()) {
|
||||
outFile.write(reinterpret_cast<const char*>(result_data), ggml_nbytes(result));
|
||||
outFile.close();
|
||||
std::cout << "Data successfully written to " + bin_file << std::endl;
|
||||
} else {
|
||||
std::cerr << "Error opening file!" << std::endl;
|
||||
}
|
||||
|
||||
free(result_data);
|
||||
// 11. Free memory and exit
|
||||
ggml_free(ctx_cgraph);
|
||||
ggml_gallocr_free(allocr);
|
||||
ggml_free(ctx);
|
||||
ggml_backend_buffer_free(buffer);
|
||||
ggml_backend_free(backend);
|
||||
}
|
||||
|
||||
enum model_output_type {
|
||||
conv3d,
|
||||
patch_embed,
|
||||
patch_win_attn_scatter,
|
||||
first_attn_layer,
|
||||
last_attn_layer,
|
||||
attn_softmax,
|
||||
final_layer,
|
||||
};
|
||||
|
||||
static void debug_dump_img_embed(struct llava_context * ctx_llava, model_output_type output_type) {
|
||||
constexpr int ih = 140;
|
||||
constexpr int iw = 196;
|
||||
// constexpr int ih = 56;
|
||||
// constexpr int iw = 56;
|
||||
// int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
||||
int n_embd = 1280;
|
||||
int merge = 1;
|
||||
if (output_type == model_output_type::final_layer) {
|
||||
n_embd = 2048;
|
||||
merge = 2;
|
||||
}
|
||||
else if (output_type == model_output_type::attn_softmax) {
|
||||
merge = 1;
|
||||
n_embd = (ih/14/merge) * (iw/14/merge) * 16;
|
||||
}
|
||||
|
||||
int ne = (ih/14/merge) * (iw/14/merge) * n_embd;
|
||||
float vals[iw * ih * 3];
|
||||
// float embd[ne];
|
||||
std::vector<float> embd;
|
||||
embd.resize(ne);
|
||||
|
||||
for (int i = 0; i < iw*ih; i++)
|
||||
{
|
||||
for (int c = 0; c < 3; c++)
|
||||
vals[i * 3 + c] = (float)i / (iw*ih);
|
||||
}
|
||||
|
||||
clip_encode_float_image(ctx_llava->ctx_clip, 8, vals, ih, iw, embd.data());
|
||||
|
||||
std::string file_postfix = "";
|
||||
switch (output_type)
|
||||
{
|
||||
case model_output_type::conv3d:
|
||||
file_postfix = "conv3d";
|
||||
break;
|
||||
case model_output_type::patch_embed:
|
||||
file_postfix = "patch_embed";
|
||||
break;
|
||||
case model_output_type::patch_win_attn_scatter:
|
||||
file_postfix = "scatter";
|
||||
break;
|
||||
case model_output_type::first_attn_layer:
|
||||
file_postfix = "first_attn";
|
||||
break;
|
||||
case model_output_type::last_attn_layer:
|
||||
file_postfix = "last_attn";
|
||||
break;
|
||||
case model_output_type::attn_softmax:
|
||||
file_postfix = "attn_softmax";
|
||||
break;
|
||||
case model_output_type::final_layer:
|
||||
file_postfix = "final";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
auto output_path = "img_embed_" + file_postfix + ".bin";
|
||||
|
||||
std::ofstream outFile(output_path, std::ios::binary);
|
||||
if (outFile.is_open()) {
|
||||
outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
|
||||
|
||||
outFile.close();
|
||||
std::cout << "Data successfully written to ::[ " << output_path << std::endl;
|
||||
} else {
|
||||
std::cerr << "Error opening file!" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto * model = llava_init(¶ms);
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (prompt_contains_image(params.prompt)) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, "");
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
#ifndef NDEBUG
|
||||
} else if (params.image[0].empty()) {
|
||||
auto ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
// debug_test_mrope_2d();
|
||||
debug_dump_img_embed(ctx_llava, model_output_type::final_layer);
|
||||
// debug_dump_img_embed(ctx_llava, model_output_type::last_attn_layer);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
#endif
|
||||
} else {
|
||||
for (auto & image : params.image) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, image);
|
||||
if (!image_embed) {
|
||||
LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
}
|
||||
}
|
||||
|
||||
llama_model_free(model);
|
||||
|
||||
return 0;
|
||||
}
|
||||
+13
-76
@@ -57,6 +57,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
|
||||
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
|
||||
};
|
||||
|
||||
// Quantization types. Changes to this struct must be replicated in llama-quantize.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
|
||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
|
||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
|
||||
@@ -244,56 +250,10 @@ static ggml_type parse_ggml_type(const char * arg) {
|
||||
return type;
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
|
||||
fprintf(stderr, "\n%s: invalid ggml_type '%s'\n\n", __func__, arg);
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
// Allowed tensors for arbitrary quantization with --tensor-type option
|
||||
static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
|
||||
"attn_k",
|
||||
"attn_kv_a_mqa",
|
||||
"attn_kv_b",
|
||||
"attn_o",
|
||||
"attn_output",
|
||||
"attn_q",
|
||||
"attn_q_a",
|
||||
"attn_q_b",
|
||||
"attn_qkv",
|
||||
"attn_v",
|
||||
"channel_mix_key",
|
||||
"channel_mix_receptance",
|
||||
"channel_mix_value",
|
||||
"cls",
|
||||
"cls.output",
|
||||
"cross_attn_k",
|
||||
"cross_attn_o",
|
||||
"cross_attn_q",
|
||||
"cross_attn_v",
|
||||
"ffn_act",
|
||||
"ffn_down",
|
||||
"ffn_down_exps",
|
||||
"ffn_down_shexp",
|
||||
"ffn_gate",
|
||||
"ffn_gate_exps",
|
||||
"ffn_gate_shexp",
|
||||
"ffn_up",
|
||||
"ffn_up_exps",
|
||||
"ffn_up_shexp",
|
||||
"ssm_in",
|
||||
"ssm_out",
|
||||
"time_mix_gate",
|
||||
"time_mix_key",
|
||||
"time_mix_output",
|
||||
"time_mix_receptance",
|
||||
"time_mix_value",
|
||||
};
|
||||
|
||||
// changes to this struct must be replicated in llama-quant.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
|
||||
const char * sep = strchr(data, '=');
|
||||
if (sep == nullptr) {
|
||||
@@ -306,7 +266,6 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
|
||||
printf("\n%s: missing tensor name\n\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const size_t qt_len = strlen(sep); qt_len == 1) {
|
||||
printf("\n%s: missing quantization type\n\n", __func__);
|
||||
return false;
|
||||
@@ -315,37 +274,15 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
|
||||
std::string tn(data, tn_len);
|
||||
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
|
||||
sep++;
|
||||
const std::string qt(sep);
|
||||
|
||||
bool found = false;
|
||||
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
|
||||
std::string tensor;
|
||||
tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn;
|
||||
// handle special case of cls.output
|
||||
std::string cls_output = "cls.output";
|
||||
if (tn.find(cls_output) != std::string::npos) {
|
||||
tensor = "cls.output";
|
||||
}
|
||||
// check if an allowed tensor exists and it's at the end of the kv string
|
||||
if (tensor == allowed) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) {
|
||||
printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
tensor_quantization tqz;
|
||||
tqz.name = tn;
|
||||
tqz.quant = parse_ggml_type(qt.c_str());
|
||||
tqz.quant = parse_ggml_type(sep);
|
||||
tensor_type.emplace_back(std::move(tqz));
|
||||
if (tqz.quant == GGML_TYPE_COUNT) {
|
||||
printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -1040,7 +1040,7 @@ To know the `id` of the adapter, use GET `/lora-adapters`
|
||||
|
||||
Returns information about the loaded model. See [OpenAI Models API documentation](https://platform.openai.com/docs/api-reference/models).
|
||||
|
||||
The returned list always has one single element.
|
||||
The returned list always has one single element. The `meta` field can be `null` (for example, while the model is still loading).
|
||||
|
||||
By default, model `id` field is the path to model file, specified via `-m`. You can set a custom value for model `id` field via `--alias` argument. For example, `--alias gpt-4o-mini`.
|
||||
|
||||
|
||||
Binary file not shown.
+24
-13
@@ -1429,7 +1429,7 @@ struct server_slot {
|
||||
pos = text.find(word, from_pos);
|
||||
} else {
|
||||
// otherwise, partial stop
|
||||
pos = find_partial_stop_string(word, text);
|
||||
pos = string_find_partial_stop(text, word);
|
||||
}
|
||||
|
||||
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
|
||||
@@ -2951,7 +2951,8 @@ struct server_context {
|
||||
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
// add generated tokens to cache
|
||||
{
|
||||
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
|
||||
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
||||
new_tokens[i - n_discard] = new_tokens[i];
|
||||
@@ -2996,10 +2997,7 @@ struct server_context {
|
||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
||||
|
||||
slot.n_past += 1;
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
}
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
|
||||
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
|
||||
slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
|
||||
@@ -3171,6 +3169,11 @@ struct server_context {
|
||||
|
||||
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
|
||||
}
|
||||
} else {
|
||||
// if we don't cache the prompt, we have to remove the entire KV cache
|
||||
llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
|
||||
slot.n_past = 0;
|
||||
slot.cache_tokens.clear();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3204,7 +3207,7 @@ struct server_context {
|
||||
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
|
||||
|
||||
// remove the non-common part from the cache
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
slot.cache_tokens.keep_first(slot.n_past);
|
||||
|
||||
// check if we should process the image
|
||||
if (slot.n_past < slot.n_prompt_tokens
|
||||
@@ -3221,7 +3224,8 @@ struct server_context {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
// add the image chunk to cache
|
||||
{
|
||||
const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
|
||||
slot.cache_tokens.push_back(chunk.get()); // copy
|
||||
}
|
||||
@@ -3242,9 +3246,7 @@ struct server_context {
|
||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(cur_tok);
|
||||
}
|
||||
slot.cache_tokens.push_back(cur_tok);
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
@@ -3705,6 +3707,9 @@ int main(int argc, char ** argv) {
|
||||
if (req.path == "/" || tmp.back() == "html") {
|
||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||
res.status = 503;
|
||||
} else if (req.path == "/models" || req.path == "/v1/models") {
|
||||
// allow the models endpoint to be accessed during loading
|
||||
return true;
|
||||
} else {
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
}
|
||||
@@ -4363,7 +4368,13 @@ int main(int argc, char ** argv) {
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
const auto handle_models = [¶ms, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||
server_state current_state = state.load();
|
||||
json model_meta = nullptr;
|
||||
if (current_state == SERVER_STATE_READY) {
|
||||
model_meta = ctx_server.model_meta();
|
||||
}
|
||||
|
||||
json models = {
|
||||
{"object", "list"},
|
||||
{"data", {
|
||||
@@ -4372,7 +4383,7 @@ int main(int argc, char ** argv) {
|
||||
{"object", "model"},
|
||||
{"created", std::time(0)},
|
||||
{"owned_by", "llamacpp"},
|
||||
{"meta", ctx_server.model_meta()}
|
||||
{"meta", model_meta},
|
||||
},
|
||||
}}
|
||||
};
|
||||
|
||||
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
|
||||
assert res_cache.body["content"] == res_no_cache.body["content"]
|
||||
|
||||
|
||||
def test_nocache_long_input_prompt():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is"*32,
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_completion_with_tokens_input():
|
||||
global server
|
||||
server.temperature = 0.0
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
import pytest
|
||||
|
||||
# ensure grandparent path is in sys.path
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from unit.test_tool_call import TEST_TOOL
|
||||
path = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(path))
|
||||
|
||||
import datetime
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
TIMEOUT_SERVER_START = 15*60
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.server_port = 8081
|
||||
server.n_slots = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,format", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
|
||||
])
|
||||
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"tools": tools,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
today_str = datetime.date.today().strftime(format)
|
||||
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
||||
@@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
n_predict = 1024
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
|
||||
+13
-1
@@ -643,6 +643,18 @@ static json oaicompat_completion_params_parse(
|
||||
throw std::runtime_error("Expected 'messages' to be an array");
|
||||
}
|
||||
for (auto & msg : messages) {
|
||||
std::string role = json_value(msg, "role", std::string());
|
||||
if (role != "assistant" && !msg.contains("content")) {
|
||||
throw std::runtime_error("All non-assistant messages must contain 'content'");
|
||||
}
|
||||
if (role == "assistant") {
|
||||
if (!msg.contains("content") && !msg.contains("tool_calls")) {
|
||||
throw std::runtime_error("Assistant message must contain either 'content' or 'tool_calls'!");
|
||||
}
|
||||
if (!msg.contains("content")) {
|
||||
continue; // avoid errors with no content
|
||||
}
|
||||
}
|
||||
json & content = msg.at("content");
|
||||
if (content.is_string() || content.is_null()) {
|
||||
continue;
|
||||
@@ -1153,7 +1165,7 @@ public:
|
||||
tokens.clear();
|
||||
}
|
||||
|
||||
void resize(size_t n) {
|
||||
void keep_first(size_t n) {
|
||||
GGML_ASSERT(n <= tokens.size());
|
||||
if (has_mtmd) {
|
||||
// we throw an error if we try to remove a token in the middle of an image
|
||||
|
||||
Generated
+280
-6
@@ -18,6 +18,7 @@
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"pdfjs-dist": "^5.2.133",
|
||||
"postcss": "^8.4.49",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
@@ -44,6 +45,7 @@
|
||||
"eslint": "^9.17.0",
|
||||
"eslint-plugin-react-hooks": "^5.0.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.16",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^15.14.0",
|
||||
"prettier": "^3.4.2",
|
||||
"sass-embedded": "^1.83.4",
|
||||
@@ -987,7 +989,7 @@
|
||||
"version": "0.3.8",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz",
|
||||
"integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==",
|
||||
"dev": true,
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@jridgewell/set-array": "^1.2.1",
|
||||
@@ -1002,7 +1004,7 @@
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz",
|
||||
"integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==",
|
||||
"dev": true,
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6.0.0"
|
||||
@@ -1012,30 +1014,224 @@
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz",
|
||||
"integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==",
|
||||
"dev": true,
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@jridgewell/source-map": {
|
||||
"version": "0.3.6",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz",
|
||||
"integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/gen-mapping": "^0.3.5",
|
||||
"@jridgewell/trace-mapping": "^0.3.25"
|
||||
}
|
||||
},
|
||||
"node_modules/@jridgewell/sourcemap-codec": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz",
|
||||
"integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==",
|
||||
"dev": true,
|
||||
"devOptional": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@jridgewell/trace-mapping": {
|
||||
"version": "0.3.25",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz",
|
||||
"integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==",
|
||||
"dev": true,
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@jridgewell/resolve-uri": "^3.1.0",
|
||||
"@jridgewell/sourcemap-codec": "^1.4.14"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.70.tgz",
|
||||
"integrity": "sha512-nD6NGa4JbNYSZYsTnLGrqe9Kn/lCkA4ybXt8sx5ojDqZjr2i0TWAHxx/vhgfjX+i3hCdKWufxYwi7CfXqtITSA==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@napi-rs/canvas-android-arm64": "0.1.70",
|
||||
"@napi-rs/canvas-darwin-arm64": "0.1.70",
|
||||
"@napi-rs/canvas-darwin-x64": "0.1.70",
|
||||
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.70",
|
||||
"@napi-rs/canvas-linux-arm64-gnu": "0.1.70",
|
||||
"@napi-rs/canvas-linux-arm64-musl": "0.1.70",
|
||||
"@napi-rs/canvas-linux-riscv64-gnu": "0.1.70",
|
||||
"@napi-rs/canvas-linux-x64-gnu": "0.1.70",
|
||||
"@napi-rs/canvas-linux-x64-musl": "0.1.70",
|
||||
"@napi-rs/canvas-win32-x64-msvc": "0.1.70"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-android-arm64": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.70.tgz",
|
||||
"integrity": "sha512-I/YOuQ0wbkVYxVaYtCgN42WKTYxNqFA0gTcTrHIGG1jfpDSyZWII/uHcjOo4nzd19io6Y4+/BqP8E5hJgf9OmQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"android"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-darwin-arm64": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.70.tgz",
|
||||
"integrity": "sha512-4pPGyXetHIHkw2TOJHujt3mkCP8LdDu8+CT15ld9Id39c752RcI0amDHSuMLMQfAjvusA9B5kKxazwjMGjEJpQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-darwin-x64": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.70.tgz",
|
||||
"integrity": "sha512-+2N6Os9LbkmDMHL+raknrUcLQhsXzc5CSXRbXws9C3pv/mjHRVszQ9dhFUUe9FjfPhCJznO6USVdwOtu7pOrzQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-arm-gnueabihf": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.70.tgz",
|
||||
"integrity": "sha512-QjscX9OaKq/990sVhSMj581xuqLgiaPVMjjYvWaCmAJRkNQ004QfoSMEm3FoTqM4DRoquP8jvuEXScVJsc1rqQ==",
|
||||
"cpu": [
|
||||
"arm"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-arm64-gnu": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.70.tgz",
|
||||
"integrity": "sha512-LNakMOwwqwiHIwMpnMAbFRczQMQ7TkkMyATqFCOtUJNlE6LPP/QiUj/mlFrNbUn/hctqShJ60gWEb52ZTALbVw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-arm64-musl": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.70.tgz",
|
||||
"integrity": "sha512-wBTOllEYNfJCHOdZj9v8gLzZ4oY3oyPX8MSRvaxPm/s7RfEXxCyZ8OhJ5xAyicsDdbE5YBZqdmaaeP5+xKxvtg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-riscv64-gnu": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.70.tgz",
|
||||
"integrity": "sha512-GVUUPC8TuuFqHip0rxHkUqArQnlzmlXmTEBuXAWdgCv85zTCFH8nOHk/YCF5yo0Z2eOm8nOi90aWs0leJ4OE5Q==",
|
||||
"cpu": [
|
||||
"riscv64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-x64-gnu": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.70.tgz",
|
||||
"integrity": "sha512-/kvUa2lZRwGNyfznSn5t1ShWJnr/m5acSlhTV3eXECafObjl0VBuA1HJw0QrilLpb4Fe0VLywkpD1NsMoVDROQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-linux-x64-musl": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.70.tgz",
|
||||
"integrity": "sha512-aqlv8MLpycoMKRmds7JWCfVwNf1fiZxaU7JwJs9/ExjTD8lX2KjsO7CTeAj5Cl4aEuzxUWbJPUUE2Qu9cZ1vfg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@napi-rs/canvas-win32-x64-msvc": {
|
||||
"version": "0.1.70",
|
||||
"resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.70.tgz",
|
||||
"integrity": "sha512-Q9QU3WIpwBTVHk4cPfBjGHGU4U0llQYRXgJtFtYqqGNEOKVN4OT6PQ+ve63xwIPODMpZ0HHyj/KLGc9CWc3EtQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@nodelib/fs.scandir": {
|
||||
"version": "2.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
|
||||
@@ -2001,7 +2197,7 @@
|
||||
"version": "8.14.0",
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz",
|
||||
"integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==",
|
||||
"dev": true,
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
@@ -2185,6 +2381,14 @@
|
||||
"devOptional": true,
|
||||
"license": "MIT/X11"
|
||||
},
|
||||
"node_modules/buffer-from": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz",
|
||||
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/callsites": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
|
||||
@@ -2802,6 +3006,13 @@
|
||||
"reusify": "^1.0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/fflate": {
|
||||
"version": "0.8.2",
|
||||
"resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz",
|
||||
"integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/file-entry-cache": {
|
||||
"version": "8.0.0",
|
||||
"resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz",
|
||||
@@ -4835,6 +5046,18 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/pdfjs-dist": {
|
||||
"version": "5.2.133",
|
||||
"resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-5.2.133.tgz",
|
||||
"integrity": "sha512-abE6ZWDxztt+gGFzfm4bX2ggfxUk9wsDEoFzIJm9LozaY3JdXR7jyLK4Bjs+XLXplCduuWS1wGhPC4tgTn/kzg==",
|
||||
"license": "Apache-2.0",
|
||||
"engines": {
|
||||
"node": ">=20.16.0 || >=22.3.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@napi-rs/canvas": "^0.1.67"
|
||||
}
|
||||
},
|
||||
"node_modules/picocolors": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz",
|
||||
@@ -5745,6 +5968,17 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/source-map": {
|
||||
"version": "0.6.1",
|
||||
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
|
||||
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
|
||||
"license": "BSD-3-Clause",
|
||||
"optional": true,
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/source-map-js": {
|
||||
"version": "1.2.1",
|
||||
"resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz",
|
||||
@@ -5754,6 +5988,18 @@
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/source-map-support": {
|
||||
"version": "0.5.21",
|
||||
"resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz",
|
||||
"integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"buffer-from": "^1.0.0",
|
||||
"source-map": "^0.6.0"
|
||||
}
|
||||
},
|
||||
"node_modules/space-separated-tokens": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz",
|
||||
@@ -5851,6 +6097,34 @@
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/terser": {
|
||||
"version": "5.39.1",
|
||||
"resolved": "https://registry.npmjs.org/terser/-/terser-5.39.1.tgz",
|
||||
"integrity": "sha512-Mm6+uad0ZuDtcV8/4uOZQDQ8RuiC5Pu+iZRedJtF7yA/27sPL7d++In/AJKpWZlU3SYMPPkVfwetn6sgZ66pUA==",
|
||||
"license": "BSD-2-Clause",
|
||||
"optional": true,
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/source-map": "^0.3.3",
|
||||
"acorn": "^8.8.2",
|
||||
"commander": "^2.20.0",
|
||||
"source-map-support": "~0.5.20"
|
||||
},
|
||||
"bin": {
|
||||
"terser": "bin/terser"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/terser/node_modules/commander": {
|
||||
"version": "2.20.3",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz",
|
||||
"integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/textlinestream": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz",
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "tsc -b && vite build",
|
||||
"build": "npm run format && tsc -b && vite build",
|
||||
"format": "eslint . && prettier --write .",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview"
|
||||
@@ -21,6 +21,7 @@
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"pdfjs-dist": "^5.2.133",
|
||||
"postcss": "^8.4.49",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
@@ -47,6 +48,7 @@
|
||||
"eslint": "^9.17.0",
|
||||
"eslint-plugin-react-hooks": "^5.0.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.16",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^15.14.0",
|
||||
"prettier": "^3.4.2",
|
||||
"sass-embedded": "^1.83.4",
|
||||
|
||||
@@ -16,6 +16,8 @@ export const CONFIG_DEFAULT = {
|
||||
showTokensPerSecond: false,
|
||||
showThoughtInProgress: false,
|
||||
excludeThoughtOnReq: true,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
pdfAsImage: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'edkypmxt',
|
||||
temperature: 0.8,
|
||||
@@ -43,6 +45,8 @@ export const CONFIG_DEFAULT = {
|
||||
export const CONFIG_INFO: Record<string, string> = {
|
||||
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
||||
systemMessage: 'The starting message that defines how model should behave.',
|
||||
pasteLongTextToFileLen:
|
||||
'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.',
|
||||
samplers:
|
||||
'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
|
||||
temperature:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { ClipboardEvent, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||
@@ -306,6 +306,7 @@ function ChatInput({
|
||||
onStop: () => void;
|
||||
isGenerating: boolean;
|
||||
}) {
|
||||
const { config } = useAppContext();
|
||||
const [isDrag, setIsDrag] = useState(false);
|
||||
|
||||
return (
|
||||
@@ -328,6 +329,38 @@ function ChatInput({
|
||||
{({ getRootProps, getInputProps }) => (
|
||||
<div
|
||||
className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
|
||||
// when a file is pasted to the input, we handle it here
|
||||
// if a text is pasted, and if it is long text, we will convert it to a file
|
||||
onPasteCapture={(e: ClipboardEvent<HTMLInputElement>) => {
|
||||
const text = e.clipboardData.getData('text/plain');
|
||||
if (
|
||||
text.length > 0 &&
|
||||
config.pasteLongTextToFileLen > 0 &&
|
||||
text.length > config.pasteLongTextToFileLen
|
||||
) {
|
||||
// if the text is too long, we will convert it to a file
|
||||
extraContext.addItems([
|
||||
{
|
||||
type: 'context',
|
||||
name: 'Pasted Content',
|
||||
content: text,
|
||||
},
|
||||
]);
|
||||
e.preventDefault();
|
||||
return;
|
||||
}
|
||||
|
||||
// if a file is pasted, we will handle it here
|
||||
const files = Array.from(e.clipboardData.items)
|
||||
.filter((item) => item.kind === 'file')
|
||||
.map((item) => item.getAsFile())
|
||||
.filter((file) => file !== null);
|
||||
|
||||
if (files.length > 0) {
|
||||
e.preventDefault();
|
||||
extraContext.onFileAdded(files);
|
||||
}
|
||||
}}
|
||||
{...getRootProps()}
|
||||
>
|
||||
{!isGenerating && (
|
||||
|
||||
@@ -100,6 +100,16 @@ const SETTING_SECTIONS: SettingSection[] = [
|
||||
key,
|
||||
}) as SettingFieldInput
|
||||
),
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'Paste length to file',
|
||||
key: 'pasteLongTextToFileLen',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Parse PDF as image instead of text',
|
||||
key: 'pdfAsImage',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -452,10 +462,10 @@ function SettingsModalLongInput({
|
||||
label?: string;
|
||||
}) {
|
||||
return (
|
||||
<label className="form-control mb-2">
|
||||
<div className="label inline">{label || configKey}</div>
|
||||
<label className="form-control">
|
||||
<div className="label inline text-sm">{label || configKey}</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-24"
|
||||
className="textarea textarea-bordered h-24 mb-2"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
@@ -482,9 +492,7 @@ function SettingsModalShortInput({
|
||||
<>
|
||||
{/* on mobile, we simply show the help message here */}
|
||||
{helpMsg && (
|
||||
<div className="block md:hidden mb-1">
|
||||
<b>{label || configKey}</b>
|
||||
<br />
|
||||
<div className="block mb-1 opacity-75">
|
||||
<p className="text-xs">{helpMsg}</p>
|
||||
</div>
|
||||
)}
|
||||
@@ -493,11 +501,6 @@ function SettingsModalShortInput({
|
||||
<div tabIndex={0} role="button" className="font-bold hidden md:block">
|
||||
{label || configKey}
|
||||
</div>
|
||||
{helpMsg && (
|
||||
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
|
||||
{helpMsg}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
|
||||
@@ -2,6 +2,17 @@ import { useState } from 'react';
|
||||
import { MessageExtra } from '../utils/types';
|
||||
import toast from 'react-hot-toast';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import * as pdfjs from 'pdfjs-dist';
|
||||
import pdfjsWorkerSrc from 'pdfjs-dist/build/pdf.worker.min.mjs?url';
|
||||
import { TextContent, TextItem } from 'pdfjs-dist/types/src/display/api';
|
||||
|
||||
pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
|
||||
|
||||
// This file handles uploading extra context items (a.k.a files)
|
||||
// It allows processing these kinds of files:
|
||||
// - image files (converted to base64)
|
||||
// - text files (including code files)
|
||||
// - pdf (converted to text)
|
||||
|
||||
// Interface describing the API returned by the hook
|
||||
export interface ChatExtraContextApi {
|
||||
@@ -13,7 +24,7 @@ export interface ChatExtraContextApi {
|
||||
}
|
||||
|
||||
export function useChatExtraContext(): ChatExtraContextApi {
|
||||
const { serverProps } = useAppContext();
|
||||
const { serverProps, config } = useAppContext();
|
||||
const [items, setItems] = useState<MessageExtra[]>([]);
|
||||
|
||||
const addItems = (newItems: MessageExtra[]) => {
|
||||
@@ -28,6 +39,8 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||
setItems([]);
|
||||
};
|
||||
|
||||
const isSupportVision = serverProps?.modalities?.vision;
|
||||
|
||||
const onFileAdded = (files: File[]) => {
|
||||
for (const file of files) {
|
||||
const mimeType = file.type;
|
||||
@@ -38,7 +51,7 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||
}
|
||||
|
||||
if (mimeType.startsWith('image/')) {
|
||||
if (!serverProps?.modalities?.vision) {
|
||||
if (!isSupportVision) {
|
||||
toast.error('Multimodal is not supported by this server or model.');
|
||||
break;
|
||||
}
|
||||
@@ -69,7 +82,43 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||
toast.error('Video and audio files are not supported yet.');
|
||||
break;
|
||||
} else if (mimeType.startsWith('application/pdf')) {
|
||||
toast.error('PDF files are not supported yet.');
|
||||
if (config.pdfAsImage && !isSupportVision) {
|
||||
toast(
|
||||
'Multimodal is not supported, PDF will be converted to text instead of image.'
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
const promise =
|
||||
config.pdfAsImage && isSupportVision
|
||||
? convertPDFToImage(file).then((base64Urls) => {
|
||||
addItems(
|
||||
base64Urls.map((base64Url) => ({
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
}))
|
||||
);
|
||||
})
|
||||
: convertPDFToText(file).then((content) => {
|
||||
if (isSupportVision) {
|
||||
toast.success(
|
||||
'PDF file converted to text. You can also convert it to image, see in Settings.'
|
||||
);
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
name: file.name,
|
||||
content,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
promise.catch((error) => {
|
||||
console.error(error);
|
||||
toast.error('Failed to parse PDF file.');
|
||||
});
|
||||
break;
|
||||
} else {
|
||||
// Because there can be many text file types (like code file), we will not check the mime type
|
||||
@@ -105,11 +154,69 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||
};
|
||||
}
|
||||
|
||||
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
resolve(event.target.result as ArrayBuffer);
|
||||
} else {
|
||||
reject(new Error('Failed to read file.'));
|
||||
}
|
||||
};
|
||||
reader.readAsArrayBuffer(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function convertPDFToText(file: File): Promise<string> {
|
||||
const buffer = await getFileAsBuffer(file);
|
||||
const pdf = await pdfjs.getDocument(buffer).promise;
|
||||
const numPages = pdf.numPages;
|
||||
const textContentPromises: Promise<TextContent>[] = [];
|
||||
for (let i = 1; i <= numPages; i++) {
|
||||
textContentPromises.push(
|
||||
pdf.getPage(i).then((page) => page.getTextContent())
|
||||
);
|
||||
}
|
||||
const textContents = await Promise.all(textContentPromises);
|
||||
const textItems = textContents.flatMap((textContent: TextContent) =>
|
||||
textContent.items.map((item) => (item as TextItem).str ?? '')
|
||||
);
|
||||
return textItems.join('\n');
|
||||
}
|
||||
|
||||
// returns list of base64 images
|
||||
async function convertPDFToImage(file: File): Promise<string[]> {
|
||||
const buffer = await getFileAsBuffer(file);
|
||||
const doc = await pdfjs.getDocument(buffer).promise;
|
||||
const pages: Promise<string>[] = [];
|
||||
|
||||
for (let i = 1; i <= doc.numPages; i++) {
|
||||
const page = await doc.getPage(i);
|
||||
const viewport = page.getViewport({ scale: 1.5 });
|
||||
const canvas = document.createElement('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
canvas.width = viewport.width;
|
||||
canvas.height = viewport.height;
|
||||
if (!ctx) {
|
||||
throw new Error('Failed to get 2D context from canvas');
|
||||
}
|
||||
const task = page.render({ canvasContext: ctx, viewport: viewport });
|
||||
pages.push(
|
||||
task.promise.then(() => {
|
||||
return canvas.toDataURL();
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return await Promise.all(pages);
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
// This code is a heuristic to determine if a string is likely not binary.
|
||||
// It is necessary because input file can have various mime types which we don't have time to investigate.
|
||||
// For example, a python file can be text/plain, application/x-python, etc.
|
||||
export function isLikelyNotBinary(str: string): boolean {
|
||||
function isLikelyNotBinary(str: string): boolean {
|
||||
const options = {
|
||||
prefixLength: 1024 * 10, // Check the first 10KB of the string
|
||||
suspiciousCharThresholdRatio: 0.15, // Allow up to 15% suspicious chars
|
||||
|
||||
@@ -3,11 +3,11 @@ import react from '@vitejs/plugin-react';
|
||||
import { viteSingleFile } from 'vite-plugin-singlefile';
|
||||
import path from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
import zlib from 'node:zlib';
|
||||
import * as fflate from 'fflate';
|
||||
|
||||
/* eslint-disable */
|
||||
|
||||
const MAX_BUNDLE_SIZE = 1.5 * 1024 * 1024; // only increase when absolutely necessary
|
||||
const MAX_BUNDLE_SIZE = 2 * 1024 * 1024; // only increase when absolutely necessary
|
||||
|
||||
const GUIDE_FOR_FRONTEND = `
|
||||
<!--
|
||||
@@ -33,9 +33,10 @@ const BUILD_PLUGINS = [
|
||||
},
|
||||
writeBundle() {
|
||||
const outputIndexHtml = path.join(config.build.outDir, 'index.html');
|
||||
const content =
|
||||
let content =
|
||||
GUIDE_FOR_FRONTEND + '\n' + fs.readFileSync(outputIndexHtml, 'utf-8');
|
||||
const compressed = zlib.gzipSync(Buffer.from(content, 'utf-8'), {
|
||||
content = content.replace(/\r/g, ''); // remove windows-style line endings
|
||||
const compressed = fflate.gzipSync(Buffer.from(content, 'utf-8'), {
|
||||
level: 9,
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user