mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 16:17:40 +02:00
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bbeb89d76c | |||
| ff806a110d | |||
| d5003b6e4d | |||
| 2635ac76e8 | |||
| 70a8309114 | |||
| c91faf997f | |||
| bf76ac77be | |||
| a09a00e502 | |||
| 2bacb1eb77 | |||
| d6e7b033a4 | |||
| fa595462ca | |||
| a817a22bc6 | |||
| eff06702b2 | |||
| e77056f9b2 | |||
| 935a340292 | |||
| d8794eecd5 | |||
| 36a694c965 | |||
| a4701c98f7 | |||
| 994118a183 | |||
| c84e6d6db5 | |||
| fa8feaed34 | |||
| 846262d787 | |||
| 6dcd824fce | |||
| d4b0c22f9e | |||
| e48034dfc9 | |||
| 048a490f76 | |||
| db44417b02 | |||
| d05fe1d7da | |||
| 0754b7b6fe | |||
| 09294365a9 | |||
| 63d93d1733 | |||
| c5a3bc39b1 | |||
| 9dbb372610 | |||
| 228e836344 | |||
| ed23489f42 | |||
| 457e2288c9 | |||
| e8ec7ab058 | |||
| 1a03cf47f6 | |||
| b97ebdc98f | |||
| 2098fd6169 | |||
| ab6120cde5 | |||
| c3c1505392 | |||
| 05e141a6b3 |
@@ -12,6 +12,8 @@ body:
|
||||
after recreating the CMake build directory and with `-DGGML_CCACHE=OFF`.
|
||||
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
|
||||
by clearing `~/.cache/ccache` (on Linux).
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: commit
|
||||
attributes:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
name: Bug (model use)
|
||||
description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module).
|
||||
description: Something goes wrong when running a model (crashes, garbled outputs, etc.).
|
||||
title: "Eval bug: "
|
||||
labels: ["bug-unconfirmed", "model evaluation"]
|
||||
body:
|
||||
@@ -12,6 +12,8 @@ body:
|
||||
If you encountered the issue while using an external UI (e.g. ollama),
|
||||
please reproduce your issue using one of the examples/binaries in this repository.
|
||||
The `llama-completion` binary can be used for simple and reproducible model inference.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
|
||||
@@ -10,6 +10,8 @@ body:
|
||||
This issue template is intended for miscellaneous bugs that don't fit into any other category.
|
||||
If you encountered the issue while using an external UI (e.g. ollama),
|
||||
please reproduce your issue using one of the examples/binaries in this repository.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
- type: textarea
|
||||
id: version
|
||||
attributes:
|
||||
|
||||
@@ -8,6 +8,8 @@ body:
|
||||
value: |
|
||||
[Please post your idea first in Discussion if there is not yet a consensus for this enhancement request. This will help to keep this issue tracker focused on enhancements that the community has agreed needs to be implemented.](https://github.com/ggml-org/llama.cpp/discussions/categories/ideas)
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: checkboxes
|
||||
id: prerequisites
|
||||
attributes:
|
||||
|
||||
@@ -8,6 +8,8 @@ body:
|
||||
value: |
|
||||
Don't forget to check for any [duplicate research issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3A%22research+%F0%9F%94%AC%22)
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: checkboxes
|
||||
id: research-stage
|
||||
attributes:
|
||||
|
||||
@@ -9,6 +9,8 @@ body:
|
||||
Don't forget to [check for existing refactor issue tickets](https://github.com/ggml-org/llama.cpp/issues?q=is%3Aopen+is%3Aissue+label%3Arefactoring) in case it's already covered.
|
||||
Also you may want to check [Pull request refactor label as well](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Aopen+is%3Apr+label%3Arefactoring) for duplicates too.
|
||||
|
||||
Please fill out this template yourself, copypasting language model outputs is [strictly prohibited](https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md#ai-usage-policy).
|
||||
|
||||
- type: textarea
|
||||
id: background-description
|
||||
attributes:
|
||||
|
||||
@@ -4,6 +4,7 @@ General:
|
||||
- By very precise and concise when writing code, comments, explanations, etc.
|
||||
- PR and commit titles format: `<module> : <title>`. Lookup recents for examples
|
||||
- Don't try to build or run the code unless you are explicitly asked to do so
|
||||
- Use the `gh` CLI tool when querying PRs, issues, or other GitHub resources
|
||||
|
||||
Coding:
|
||||
- When in doubt, always refer to the CONTRIBUTING.md file of the project
|
||||
|
||||
+20
-11
@@ -248,6 +248,8 @@ std::vector<std::string> common_arg::get_env() const {
|
||||
|
||||
// Helper function to parse tensor buffer override strings
|
||||
static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
@@ -425,6 +427,10 @@ static bool parse_bool_value(const std::string & value) {
|
||||
}
|
||||
}
|
||||
|
||||
[[noreturn]] static void arg_removed(const std::string & msg) {
|
||||
throw std::invalid_argument("the argument has been removed. " + msg);
|
||||
}
|
||||
|
||||
//
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
@@ -803,6 +809,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
|
||||
if (dev_names.size() == 1 && dev_names[0] == "none") {
|
||||
devices.push_back(nullptr);
|
||||
} else {
|
||||
ggml_backend_load_all();
|
||||
for (const auto & device : dev_names) {
|
||||
auto * dev = ggml_backend_dev_by_name(device.c_str());
|
||||
if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
@@ -820,6 +827,7 @@ static void add_rpc_devices(const std::string & servers) {
|
||||
if (rpc_servers.empty()) {
|
||||
throw std::invalid_argument("no RPC servers specified");
|
||||
}
|
||||
ggml_backend_load_all();
|
||||
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
|
||||
if (!rpc_reg) {
|
||||
throw std::invalid_argument("failed to find RPC backend");
|
||||
@@ -1016,9 +1024,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
|
||||
params.use_color = tty_can_use_colors();
|
||||
|
||||
// load dynamic backends
|
||||
ggml_backend_load_all();
|
||||
|
||||
common_params_context ctx_arg(params);
|
||||
ctx_arg.print_usage = print_usage;
|
||||
ctx_arg.ex = ex;
|
||||
@@ -2275,6 +2280,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--list-devices"},
|
||||
"print list of available devices and exit",
|
||||
[](common_params &) {
|
||||
ggml_backend_load_all();
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
@@ -2864,7 +2870,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--tools"}, "TOOL1,TOOL2,...",
|
||||
"experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
|
||||
"specify \"all\" to enable all tools\n"
|
||||
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
|
||||
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
@@ -3380,7 +3386,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-draft-poll", "--poll-draft"}, "<0|1>",
|
||||
"Use polling to wait for draft model work (default: same as --poll])",
|
||||
"Use polling to wait for draft model work (default: same as --poll)",
|
||||
[](common_params & params, int value) {
|
||||
params.speculative.draft.cpuparams.poll = value;
|
||||
}
|
||||
@@ -3715,35 +3721,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--draft", "--draft-n", "--draft-max"}, "N",
|
||||
"the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max");
|
||||
arg_removed("use --spec-draft-n-max or --spec-ngram-mod-n-max");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MAX"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-min", "--draft-n-min"}, "N",
|
||||
"the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min");
|
||||
arg_removed("use --spec-draft-n-min or --spec-ngram-mod-n-min");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
"the argument has been removed. use the respective --spec-ngram-*-size-n or --spec-ngram-mod-n-match",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-size-n");
|
||||
arg_removed("use the respective --spec-ngram-*-size-n");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-m"}, "N",
|
||||
"the argument has been removed. use the respective --spec-ngram-*-size-m",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-size-m");
|
||||
arg_removed("use the respective --spec-ngram-*-size-m");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
"the argument has been removed. use the respective --spec-ngram-*-min-hits",
|
||||
[](common_params & /*params*/, int /*value*/) {
|
||||
throw std::invalid_argument("the argument has been removed. use the respective --spec-ngram-*-min-hits");
|
||||
arg_removed("use the respective --spec-ngram-*-min-hits");
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
@@ -3794,7 +3800,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-algorithm"}, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
|
||||
string_format(
|
||||
"diffusion algorithm: 0=DIFFUSION_ALGORITHM_ORIGIN, 1=DIFFUSION_ALGORITHM_ENTROPY_BASED, "
|
||||
"2=DIFFUSION_ALGORITHM_MARGIN_BASED, 3=DIFFUSION_ALGORITHM_RANDOM, "
|
||||
"4=DIFFUSION_ALGORITHM_CONFIDENCE_BASED (default: %d)", params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
|
||||
@@ -136,10 +136,10 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
|
||||
if (!end.empty()) {
|
||||
if (!start.empty()) {
|
||||
// Standard tag-based: optional(<think>reasoning</think>)
|
||||
return p.optional(start + p.reasoning(p.until(end)) + end + p.space());
|
||||
return p.optional(p.optspace(start) + p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
|
||||
}
|
||||
// Delimiter-style (empty start)
|
||||
return p.optional(p.reasoning(p.until(end)) + end + p.space());
|
||||
return p.optional(p.reasoning(p.until(trim_whitespace(end))) + p.optspace(end));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,7 +186,6 @@ common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const
|
||||
common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
// Build effective field names with dot notation if function_field is set
|
||||
std::string name_field = format.name_field;
|
||||
@@ -225,8 +224,7 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
|
||||
tool_start = format.per_call_start;
|
||||
}
|
||||
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser +
|
||||
p.end();
|
||||
return ctx.reasoning_parser + p.optional(p.content(p.until(tool_start))) + tools_parser + p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
|
||||
@@ -270,7 +268,6 @@ common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p,
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
@@ -336,14 +333,12 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
|
||||
|
||||
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
|
||||
p.end();
|
||||
return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
auto until_suffix = p.rule("until-suffix", p.until(arguments.value_suffix));
|
||||
|
||||
@@ -471,8 +466,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
|
||||
std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker);
|
||||
return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls +
|
||||
p.end();
|
||||
return ctx.reasoning_parser + p.optional(p.content(content_before_tools)) + tool_calls + p.end();
|
||||
}
|
||||
|
||||
} // namespace autoparser
|
||||
|
||||
@@ -342,7 +342,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (left_trimmed.empty() && !diff.right.empty()) {
|
||||
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
|
||||
if (start.empty()) {
|
||||
start = trim_leading_whitespace(diff.right);
|
||||
start = diff.right;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -353,7 +353,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
|
||||
start = seg[seg.size() - 2].value;
|
||||
}
|
||||
end = trim_trailing_whitespace(diff.left);
|
||||
end = diff.left;
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -445,14 +445,14 @@ void analyze_reasoning::compare_reasoning_scope() {
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
start = result.tags["pre"];
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
end = result.tags["post"];
|
||||
} else {
|
||||
auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())));
|
||||
});
|
||||
result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
end = result.tags["post"];
|
||||
} else {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Unable to extract reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__);
|
||||
mode = reasoning_mode::NONE;
|
||||
|
||||
@@ -816,6 +816,32 @@ common_peg_parser common_chat_peg_builder::prefix(const std::string & s, const s
|
||||
return literal(s.substr(0, s.rfind(delimiter)));
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::optspace(const std::string & tag) {
|
||||
auto parser = eps();
|
||||
size_t end_of_prefix_space = tag.size();
|
||||
size_t start_of_suffix_space = tag.size();
|
||||
for (size_t i = 0; i < tag.size(); i++) {
|
||||
if (!std::isspace(tag[i])) {
|
||||
end_of_prefix_space = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = tag.size(); i > 0; i--) {
|
||||
if (!std::isspace(tag[i - 1])) {
|
||||
start_of_suffix_space = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < end_of_prefix_space; i++) {
|
||||
parser += optional(literal(std::string(1, tag[i])));
|
||||
}
|
||||
parser += literal(tag.substr(end_of_prefix_space, start_of_suffix_space - end_of_prefix_space));
|
||||
for (size_t i = start_of_suffix_space; i < tag.size(); i++) {
|
||||
parser += optional(literal(std::string(1, tag[i])));
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
|
||||
common_peg_parser common_chat_peg_builder::standard_json_tools(
|
||||
const std::string & section_start,
|
||||
const std::string & section_end,
|
||||
|
||||
@@ -96,6 +96,9 @@ class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
// Return a parser that parses the prefix of a string, up to a given delimiter.
|
||||
common_peg_parser prefix(const std::string & s, const std::string & delimiter = {});
|
||||
|
||||
// Return a parser that parses all elements of tag, but leading and trailing spaces are optional
|
||||
common_peg_parser optspace(const std::string & tag);
|
||||
|
||||
// Legacy-compatible helper for building standard JSON tool calls
|
||||
// Used by tests and manual parsers
|
||||
// name_key/args_key: JSON key names for function name and arguments
|
||||
|
||||
+29
-20
@@ -2116,22 +2116,38 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::string common_chat_templates_generation_prompt(const common_chat_template & tmpl, const autoparser::generation_params & inputs) {
|
||||
autoparser::generation_params params = inputs;
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
|
||||
size_t prefix_len = 0;
|
||||
size_t min_size = std::min(no_gen_prompt.size(), gen_prompt.size());
|
||||
while (prefix_len < min_size && no_gen_prompt[prefix_len] == gen_prompt[prefix_len]) {
|
||||
prefix_len++;
|
||||
}
|
||||
return gen_prompt.substr(prefix_len);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs) {
|
||||
autoparser::generation_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl =
|
||||
params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
params.enable_thinking = inputs.enable_thinking;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
@@ -2153,14 +2169,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
workaround::func_args_not_string(params.messages);
|
||||
}
|
||||
|
||||
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right + diff.suffix;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.generation_prompt = common_chat_templates_generation_prompt(tmpl, params);
|
||||
|
||||
params.extra_context = common_chat_extra_context();
|
||||
for (auto el : inputs.chat_template_kwargs) {
|
||||
@@ -2212,8 +2221,8 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = autoparser.reasoning.start;
|
||||
auto_params.thinking_end_tag = autoparser.reasoning.end;
|
||||
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
|
||||
auto_params.thinking_end_tag = trim_whitespace(autoparser.reasoning.end);
|
||||
}
|
||||
auto_params.generation_prompt = params.generation_prompt;
|
||||
common_peg_arena arena;
|
||||
|
||||
@@ -158,6 +158,8 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
|
||||
for (size_t i = 0; i < cur_p->size; i++) {
|
||||
if (cur_p->data[i].id != forced) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
} else {
|
||||
cur_p->data[i].logit = +INFINITY; // force the token
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,14 +252,14 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
|
||||
size_t create_checkpoint(int n_tokens_prompt) {
|
||||
int slot_id = 0;
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.n_tokens = n_tokens_prompt;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
@@ -272,7 +272,7 @@ struct common_speculative_state_draft : public common_speculative_state {
|
||||
size_t restore_checkpoint() {
|
||||
int slot_id = 0;
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (n != ckpt.size()) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size());
|
||||
|
||||
+20
-4
@@ -2889,6 +2889,20 @@ class LlamaModel(TextModel):
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
|
||||
def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
|
||||
# Mirror the BF16 Q/K RoPE permutation site in modify_tensors; the NVFP4 path bypasses it.
|
||||
if self.undo_permute:
|
||||
n_head = self.find_hparam(["n_heads", "num_attention_heads"], optional=True)
|
||||
n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"], optional=True)
|
||||
if n_head is not None:
|
||||
if name.endswith("q_proj.weight"):
|
||||
weight = LlamaModel.permute(weight, n_head, n_head)
|
||||
scale = LlamaModel.permute(scale, n_head, n_head)
|
||||
elif name.endswith("k_proj.weight"):
|
||||
weight = LlamaModel.permute(weight, n_head, n_kv_head)
|
||||
scale = LlamaModel.permute(scale, n_head, n_kv_head)
|
||||
super()._repack_nvfp4(name, weight, scale, scale2, input_scale)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
@@ -12702,11 +12716,12 @@ class MistralModel(LlamaModel):
|
||||
def set_mistral_config(gguf_writer: gguf.GGUFWriter, hparams: dict):
|
||||
if "yarn" in hparams:
|
||||
yarn_params = hparams["yarn"]
|
||||
mscale_all_dim = 1.0 if not yarn_params["apply_scale"] else 0.0
|
||||
gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
||||
gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
||||
gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
||||
gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
||||
gguf_writer.add_rope_scaling_yarn_log_mul(mscale_all_dim)
|
||||
gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
||||
|
||||
if "llama_4_scaling" in hparams:
|
||||
@@ -13232,17 +13247,18 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
}
|
||||
|
||||
# only used when byteswapping data. Only correct size is needed
|
||||
# TODO: uncomment uint64, uint32, and uint16, ref: https://github.com/pytorch/pytorch/issues/58734
|
||||
_dtype_byteswap_map: dict[torch.dtype, type] = {
|
||||
torch.float64: np.float64,
|
||||
torch.float32: np.float32,
|
||||
torch.bfloat16: np.float16,
|
||||
torch.float16: np.float16,
|
||||
torch.int64: np.int64,
|
||||
torch.uint64: np.uint64,
|
||||
# torch.uint64: np.uint64,
|
||||
torch.int32: np.int32,
|
||||
torch.uint32: np.uint32,
|
||||
# torch.uint32: np.uint32,
|
||||
torch.int16: np.int16,
|
||||
torch.uint16: np.uint16,
|
||||
# torch.uint16: np.uint16,
|
||||
torch.int8: np.int8,
|
||||
torch.uint8: np.uint8,
|
||||
torch.bool: np.uint8,
|
||||
|
||||
+108
-21
@@ -33,18 +33,18 @@ An example to use this approach can be the rewriting of source code by a LLM.
|
||||
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
|
||||
|
||||
```
|
||||
llama-server [...] --spec-type ngram-simple --draft-max 64
|
||||
llama-server [...] --spec-type ngram-simple --spec-draft-n-max 64
|
||||
```
|
||||
|
||||
#### n-gram Map Key (`ngram-map-k`)
|
||||
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`, default is 1) before generating drafts.
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-map-k-min-hits`, default is 1) before generating drafts.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:**
|
||||
```
|
||||
llama-server [...] --spec-type ngram-map-k --draft-max 64
|
||||
llama-server [...] --spec-type ngram-map-k --spec-draft-n-max 64
|
||||
```
|
||||
|
||||
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
|
||||
@@ -55,7 +55,7 @@ The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:** Server options to be used if there are a lot of longer repetitions.
|
||||
```
|
||||
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 --draft-max 64
|
||||
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-map-k4v-size-n 8 --spec-ngram-map-k4v-size-m 8 --spec-ngram-map-k4v-min-hits 2 --spec-draft-n-max 64
|
||||
```
|
||||
|
||||
### n-gram Mod (`ngram-mod`)
|
||||
@@ -80,9 +80,9 @@ Currently, a single hash pool is shared across all server slots, so different re
|
||||
# notes:
|
||||
# - small `n` are not recommended
|
||||
# - MoEs require long drafts
|
||||
# - dense models: can reduce `--draft-min` and `--draft-max`
|
||||
# - dense models: can reduce `--spec-ngram-mod-n-min` and `--spec-ngram-mod-n-max`
|
||||
|
||||
llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64
|
||||
llama-server ... --spec-type ngram-mod --spec-ngram-mod-n-match 24 --spec-ngram-mod-n-min 48 --spec-ngram-mod-n-max 64
|
||||
```
|
||||
|
||||
Applications:
|
||||
@@ -105,21 +105,90 @@ Example Video:
|
||||
|
||||
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
|
||||
|
||||
### General Speculative Parameters
|
||||
|
||||
```
|
||||
--draft, --draft-n, --draft-max N number of tokens to draft for speculative decoding (default: 16)
|
||||
(env: LLAMA_ARG_DRAFT_MAX)
|
||||
--draft-min, --draft-n-min N minimum number of draft tokens to use for speculative decoding
|
||||
(default: 0)
|
||||
(env: LLAMA_ARG_DRAFT_MIN)
|
||||
[...]
|
||||
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
type of speculative decoding to use when no draft model is provided
|
||||
(default: none)
|
||||
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
|
||||
of lookup n-gram (default: 12)
|
||||
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
|
||||
of draft m-gram (default: 48)
|
||||
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
|
||||
(env: LLAMA_ARG_SPEC_TYPE)
|
||||
--spec-default use default speculative decoding
|
||||
```
|
||||
|
||||
### Draft Model Parameters
|
||||
|
||||
```
|
||||
--spec-draft-model, -md, --model-draft FNAME
|
||||
draft model for speculative decoding (default: unused)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_MODEL)
|
||||
--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]
|
||||
HuggingFace repository for the draft model
|
||||
--spec-draft-n-max N
|
||||
number of tokens to draft for speculative decoding (default: 16)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_N_MAX)
|
||||
--spec-draft-n-min N
|
||||
minimum number of draft tokens to use for speculative decoding (default: 0)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_N_MIN)
|
||||
--spec-draft-p-split, --draft-p-split P
|
||||
speculative decoding split probability (default: 0.10)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT)
|
||||
--spec-draft-p-min, --draft-p-min P
|
||||
minimum speculative decoding probability (greedy) (default: 0.75)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_P_MIN)
|
||||
--spec-draft-ctx-size, -cd, --ctx-size-draft N
|
||||
size of the prompt context for the draft model (default: 0, 0 = loaded from model)
|
||||
(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE)
|
||||
--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N
|
||||
max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
|
||||
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT)
|
||||
--spec-draft-device, -devd, --device-draft <dev1,dev2,..>
|
||||
comma-separated list of devices to use for offloading the draft model
|
||||
--spec-draft-replace, --spec-replace TARGET DRAFT
|
||||
translate the string in TARGET into DRAFT if the draft model and main model are not compatible
|
||||
```
|
||||
|
||||
### n-gram Mod Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-mod-n-match N
|
||||
ngram-mod lookup length (default: 24)
|
||||
--spec-ngram-mod-n-min N
|
||||
minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48)
|
||||
--spec-ngram-mod-n-max N
|
||||
maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64)
|
||||
```
|
||||
|
||||
### n-gram Simple Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-simple-size-n N
|
||||
ngram size N for ngram-simple speculative decoding, length of lookup n-gram (default: 12)
|
||||
--spec-ngram-simple-size-m N
|
||||
ngram size M for ngram-simple speculative decoding, length of draft m-gram (default: 48)
|
||||
--spec-ngram-simple-min-hits N
|
||||
minimum hits for ngram-simple speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### n-gram Map Key Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-map-k-size-n N
|
||||
ngram size N for ngram-map-k speculative decoding, length of lookup n-gram (default: 12)
|
||||
--spec-ngram-map-k-size-m N
|
||||
ngram size M for ngram-map-k speculative decoding, length of draft m-gram (default: 48)
|
||||
--spec-ngram-map-k-min-hits N
|
||||
minimum hits for ngram-map-k speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### n-gram Map Key-4-Values Parameters
|
||||
|
||||
```
|
||||
--spec-ngram-map-k4v-size-n N
|
||||
ngram size N for ngram-map-k4v speculative decoding, length of lookup n-gram (default: 12)
|
||||
--spec-ngram-map-k4v-size-m N
|
||||
ngram size M for ngram-map-k4v speculative decoding, length of draft m-gram (default: 48)
|
||||
--spec-ngram-map-k4v-min-hits N
|
||||
minimum hits for ngram-map-k4v speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### `--spec-type TYPE`
|
||||
@@ -140,21 +209,40 @@ Specifies a type of speculative decoding without draft model.
|
||||
./llama-server [...] --spec-type ngram-simple
|
||||
```
|
||||
|
||||
### `--spec-ngram-size-n N`
|
||||
### `--spec-ngram-*-size-n N`
|
||||
|
||||
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
|
||||
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
|
||||
|
||||
### `--spec-ngram-size-m M`
|
||||
Each n-gram implementation has its own parameter:
|
||||
|
||||
- `--spec-ngram-simple-size-n` for `ngram-simple`
|
||||
- `--spec-ngram-map-k-size-n` for `ngram-map-k`
|
||||
- `--spec-ngram-map-k4v-size-n` for `ngram-map-k4v`
|
||||
- `--spec-ngram-mod-n-match` for `ngram-mod`
|
||||
|
||||
### `--spec-ngram-*-size-m M`
|
||||
|
||||
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
|
||||
The m-gram size determines how many tokens to draft when a match is found.
|
||||
Larger values can provide more speedup but may reduce acceptance rate.
|
||||
|
||||
### `--spec-ngram-min-hits H`
|
||||
Each n-gram implementation has its own parameter:
|
||||
|
||||
- `--spec-ngram-simple-size-m` for `ngram-simple`
|
||||
- `--spec-ngram-map-k-size-m` for `ngram-map-k`
|
||||
- `--spec-ngram-map-k4v-size-m` for `ngram-map-k4v`
|
||||
|
||||
### `--spec-ngram-*-min-hits H`
|
||||
|
||||
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
|
||||
|
||||
Each n-gram implementation has its own parameter:
|
||||
|
||||
- `--spec-ngram-simple-min-hits` for `ngram-simple`
|
||||
- `--spec-ngram-map-k-min-hits` for `ngram-map-k`
|
||||
- `--spec-ngram-map-k4v-min-hits` for `ngram-map-k4v`
|
||||
|
||||
## Statistics
|
||||
Each speculative decoding implementation prints statistics.
|
||||
|
||||
@@ -180,4 +268,3 @@ statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
- `#acc tokens`: number of tokens accepted by the main model
|
||||
- `dur(b,g,a): durations of begin (new prompt), generation and accumulation (process acceptance).
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
set(TARGET llama-diffusion)
|
||||
add_library(${TARGET} STATIC diffusion.cpp diffusion.h)
|
||||
target_link_libraries(${TARGET} PUBLIC llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PUBLIC cxx_std_17)
|
||||
|
||||
set(TARGET llama-diffusion-cli)
|
||||
add_executable(${TARGET} diffusion-cli.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE llama-diffusion llama llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -12,11 +12,11 @@ The diffusion CLI supports various parameters to control the generation process:
|
||||
### Core Diffusion Parameters
|
||||
- `--diffusion-steps`: Number of diffusion steps (default: 256)
|
||||
- `--diffusion-algorithm`: Algorithm for token selection
|
||||
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: MARGIN_BASED - Margin-based selection
|
||||
- `3`: RANDOM - Random selection
|
||||
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- `0`: DIFFUSION_ALGORITHM_ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: DIFFUSION_ALGORITHM_ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: DIFFUSION_ALGORITHM_MARGIN_BASED - Margin-based selection
|
||||
- `3`: DIFFUSION_ALGORITHM_RANDOM - Random selection
|
||||
- `4`: DIFFUSION_ALGORITHM_CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- More documentation here https://github.com/DreamLM/Dream
|
||||
- `--diffusion-visual`: Enable live visualization during generation
|
||||
|
||||
|
||||
@@ -1,127 +1,23 @@
|
||||
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "diffusion.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };
|
||||
|
||||
// Unified transfer scheduling methods
|
||||
enum transfer_schedule {
|
||||
TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
|
||||
BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
|
||||
};
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps = 0;
|
||||
float temperature = 0;
|
||||
llama_token mask_token_id = LLAMA_TOKEN_NULL;
|
||||
diffusion_step_callback_t step_callback = nullptr;
|
||||
void * step_callback_user_data = nullptr;
|
||||
int32_t seed = 0;
|
||||
bool visual_mode = false;
|
||||
bool shift_logits = false; // Shift logits by -1 after decode
|
||||
|
||||
float top_p = 0.;
|
||||
int32_t top_k = 0.;
|
||||
|
||||
diffusion_algorithm algorithm = CONFIDENCE_BASED;
|
||||
transfer_schedule schedule = TIMESTEP_BASED;
|
||||
|
||||
float cfg_scale = 0.; // Config scale for classifier-free guidance
|
||||
float eps = 0.; // Timestep scheduling
|
||||
int32_t block_length = 0; // Block size (for block scheduling)
|
||||
float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
|
||||
bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
|
||||
|
||||
int32_t max_length = 0; // Maximum sequence length
|
||||
};
|
||||
|
||||
struct callback_data {
|
||||
diffusion_params * diff_params;
|
||||
const llama_vocab * vocab;
|
||||
int32_t n_input;
|
||||
};
|
||||
|
||||
static float calculate_confidence(const llama_token_data_array & cur_p,
|
||||
diffusion_algorithm algorithm,
|
||||
std::mt19937 & rng) {
|
||||
switch (algorithm) {
|
||||
case CONFIDENCE_BASED:
|
||||
return cur_p.data[cur_p.selected].p; // Selected token probability
|
||||
|
||||
case ENTROPY_BASED:
|
||||
{
|
||||
float entropy = 0.0f;
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t i = 0; i < cur_p.size; i++) {
|
||||
float prob = cur_p.data[i].p;
|
||||
entropy += prob * logf(prob + epsilon);
|
||||
}
|
||||
return -entropy; // Higher entropy = lower confidence
|
||||
}
|
||||
|
||||
case MARGIN_BASED:
|
||||
return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
|
||||
|
||||
case RANDOM:
|
||||
{
|
||||
std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
|
||||
return uniform(rng); // Random confidence
|
||||
}
|
||||
|
||||
case ORIGIN:
|
||||
return cur_p.data[cur_p.selected].p;
|
||||
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Unified transfer count calculation function
|
||||
static int32_t calculate_transfer_count(int32_t step,
|
||||
int32_t total_steps,
|
||||
int32_t remaining_masked,
|
||||
transfer_schedule schedule,
|
||||
float eps,
|
||||
const std::vector<int32_t> & num_transfer_tokens = {}) {
|
||||
switch (schedule) {
|
||||
case TIMESTEP_BASED:
|
||||
{
|
||||
float t = 1.0f - (float) step / total_steps * (1.0f - eps);
|
||||
float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
|
||||
float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
return (int32_t) (remaining_masked * p_transfer);
|
||||
}
|
||||
|
||||
case BLOCK_BASED:
|
||||
if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
|
||||
return num_transfer_tokens[step];
|
||||
}
|
||||
return remaining_masked / (total_steps - step); // Fallback
|
||||
|
||||
default:
|
||||
return remaining_masked / (total_steps - step);
|
||||
}
|
||||
}
|
||||
|
||||
static bool diffusion_step_callback(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
@@ -176,341 +72,6 @@ static bool diffusion_step_callback(int32_t step,
|
||||
return true;
|
||||
}
|
||||
|
||||
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
|
||||
if (temperature == 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<double> uniform(0.0, 1.0);
|
||||
for (int32_t i = 0; i < n_vocab; i++) {
|
||||
double noise = uniform(rng);
|
||||
// Prevent log(0)
|
||||
noise = std::max(noise, 1e-20);
|
||||
double gumbel_noise = std::pow(-std::log(noise), temperature);
|
||||
logits[i] = std::exp(logits[i]) / gumbel_noise;
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
|
||||
std::vector<int32_t> num_transfer_tokens(steps);
|
||||
|
||||
int32_t base = mask_count / steps;
|
||||
int32_t remainder = mask_count % steps;
|
||||
|
||||
for (int32_t i = 0; i < steps; i++) {
|
||||
num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
|
||||
}
|
||||
|
||||
return num_transfer_tokens;
|
||||
}
|
||||
|
||||
static void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated) {
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(params.max_length);
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(params.max_length);
|
||||
|
||||
// Setup sampler chain
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(params.max_length, 0, 1);
|
||||
batch.n_tokens = params.max_length;
|
||||
|
||||
// Pre-allocate buffers for CFG if needed
|
||||
int32_t logits_size = n_vocab * params.max_length;
|
||||
std::vector<float> cond_logits_buffer;
|
||||
std::vector<llama_token> un_x_buffer;
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
cond_logits_buffer.resize(logits_size);
|
||||
un_x_buffer.resize(params.max_length);
|
||||
}
|
||||
|
||||
// For block-based processing
|
||||
std::vector<int32_t> num_transfer_tokens;
|
||||
int32_t num_blocks = 1;
|
||||
int32_t steps_per_block = params.steps;
|
||||
|
||||
if (params.schedule == BLOCK_BASED) {
|
||||
GGML_ASSERT(params.max_length % params.block_length == 0);
|
||||
num_blocks = params.max_length / params.block_length;
|
||||
GGML_ASSERT(params.steps % num_blocks == 0);
|
||||
steps_per_block = params.steps / num_blocks;
|
||||
}
|
||||
|
||||
std::vector<float> confidence(params.max_length);
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
int64_t time_start = ggml_time_us();
|
||||
|
||||
for (int block_num = 0; block_num < num_blocks; block_num++) {
|
||||
int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
|
||||
int32_t block_end = (params.schedule == BLOCK_BASED) ?
|
||||
std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
|
||||
params.max_length;
|
||||
|
||||
// Count masked tokens in current block for block-based processing
|
||||
if (params.schedule == BLOCK_BASED) {
|
||||
int32_t block_mask_count = 0;
|
||||
for (int i = block_start; i < block_end; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
block_mask_count++;
|
||||
}
|
||||
}
|
||||
num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
|
||||
}
|
||||
|
||||
for (int32_t step = 0; step < steps_per_block; step++) {
|
||||
int32_t global_step = block_num * steps_per_block + step;
|
||||
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(
|
||||
global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup batch
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
float * logits = nullptr;
|
||||
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate conditional");
|
||||
break;
|
||||
}
|
||||
float * cond_logits_ptr = llama_get_logits(ctx);
|
||||
std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
|
||||
|
||||
// Unconditional generation (mask input)
|
||||
std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
|
||||
for (int32_t i = 0; i < n_input; i++) {
|
||||
un_x_buffer[i] = params.mask_token_id;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = un_x_buffer[i];
|
||||
}
|
||||
ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate unconditional");
|
||||
break;
|
||||
}
|
||||
float * uncond_logits = llama_get_logits(ctx);
|
||||
|
||||
// Apply CFG
|
||||
for (int32_t i = 0; i < logits_size; i++) {
|
||||
cond_logits_buffer[i] =
|
||||
uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
|
||||
}
|
||||
logits = cond_logits_buffer.data();
|
||||
} else {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
|
||||
break;
|
||||
}
|
||||
logits = llama_get_logits(ctx);
|
||||
}
|
||||
|
||||
if (!logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
if (params.shift_logits) {
|
||||
return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
|
||||
}
|
||||
return logits + (pos) *n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
// For block-based, only consider current block
|
||||
if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (params.add_gumbel_noise && params.temperature > 0.0f) {
|
||||
add_gumbel_noise(logits, n_vocab, params.temperature, rng);
|
||||
}
|
||||
|
||||
if (params.algorithm == ORIGIN) {
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
float p_transfer = (float) transfer_count / mask_positions.size();
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
(size_t) n_vocab,
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float conf = calculate_confidence(cur_p, params.algorithm, rng);
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(conf, i);
|
||||
}
|
||||
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
|
||||
if (transfer_count > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(),
|
||||
confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
|
||||
confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
for (size_t i = 0; i < confidences.size(); i++) {
|
||||
float conf_logit = confidences[i].first / params.alg_temp;
|
||||
conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
conf_candidates.data(),
|
||||
conf_candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int32_t selected_idx = conf_array.selected;
|
||||
int32_t mask_idx = selected_idx;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0,
|
||||
total_time / 1000.0 / params.steps,
|
||||
total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = params.max_length;
|
||||
}
|
||||
|
||||
static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
|
||||
if (!use_chat_template) {
|
||||
return prompt;
|
||||
@@ -631,10 +192,10 @@ int main(int argc, char ** argv) {
|
||||
GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
|
||||
|
||||
if (params.diffusion.eps) {
|
||||
diff_params.schedule = TIMESTEP_BASED;
|
||||
diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED;
|
||||
diff_params.eps = params.diffusion.eps;
|
||||
} else if (params.diffusion.block_length) {
|
||||
diff_params.schedule = BLOCK_BASED;
|
||||
diff_params.schedule = DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED;
|
||||
diff_params.block_length = params.diffusion.block_length;
|
||||
}
|
||||
|
||||
@@ -653,8 +214,17 @@ int main(int argc, char ** argv) {
|
||||
callback_data cb_data = { &diff_params, vocab, n_input };
|
||||
diff_params.step_callback_user_data = &cb_data;
|
||||
|
||||
const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
|
||||
const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
|
||||
const char * alg_names[] = {
|
||||
"DIFFUSION_ALGORITHM_ORIGIN",
|
||||
"DIFFUSION_ALGORITHM_ENTROPY_BASED",
|
||||
"DIFFUSION_ALGORITHM_MARGIN_BASED",
|
||||
"DIFFUSION_ALGORITHM_RANDOM",
|
||||
"DIFFUSION_ALGORITHM_CONFIDENCE_BASED",
|
||||
};
|
||||
const char * sched_names[] = {
|
||||
"DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED",
|
||||
"DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED",
|
||||
};
|
||||
const char * alg_name =
|
||||
(diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
|
||||
const char * sched_name =
|
||||
@@ -666,11 +236,11 @@ int main(int argc, char ** argv) {
|
||||
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
|
||||
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
|
||||
if (diff_params.schedule == TIMESTEP_BASED) {
|
||||
if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED) {
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
|
||||
}
|
||||
if (diff_params.schedule == BLOCK_BASED) {
|
||||
if (diff_params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
#include "diffusion.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
static float calculate_confidence(const llama_token_data_array & cur_p,
|
||||
diffusion_algorithm algorithm,
|
||||
std::mt19937 & rng) {
|
||||
switch (algorithm) {
|
||||
case DIFFUSION_ALGORITHM_CONFIDENCE_BASED:
|
||||
return cur_p.data[cur_p.selected].p; // Selected token probability
|
||||
|
||||
case DIFFUSION_ALGORITHM_ENTROPY_BASED:
|
||||
{
|
||||
float entropy = 0.0f;
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t i = 0; i < cur_p.size; i++) {
|
||||
float prob = cur_p.data[i].p;
|
||||
entropy += prob * logf(prob + epsilon);
|
||||
}
|
||||
return -entropy; // Higher entropy = lower confidence
|
||||
}
|
||||
|
||||
case DIFFUSION_ALGORITHM_MARGIN_BASED:
|
||||
return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
|
||||
|
||||
case DIFFUSION_ALGORITHM_RANDOM:
|
||||
{
|
||||
std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
|
||||
return uniform(rng); // Random confidence
|
||||
}
|
||||
|
||||
case DIFFUSION_ALGORITHM_ORIGIN:
|
||||
return cur_p.data[cur_p.selected].p;
|
||||
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Unified transfer count calculation function
|
||||
static int32_t calculate_transfer_count(int32_t step,
|
||||
int32_t total_steps,
|
||||
int32_t remaining_masked,
|
||||
diffusion_transfer_schedule schedule,
|
||||
float eps,
|
||||
const std::vector<int32_t> & num_transfer_tokens = {}) {
|
||||
switch (schedule) {
|
||||
case DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED:
|
||||
{
|
||||
float t = 1.0f - (float) step / total_steps * (1.0f - eps);
|
||||
float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
|
||||
float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
return (int32_t) (remaining_masked * p_transfer);
|
||||
}
|
||||
|
||||
case DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED:
|
||||
if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
|
||||
return num_transfer_tokens[step];
|
||||
}
|
||||
return remaining_masked / (total_steps - step); // Fallback
|
||||
|
||||
default:
|
||||
return remaining_masked / (total_steps - step);
|
||||
}
|
||||
}
|
||||
|
||||
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
|
||||
if (temperature == 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<double> uniform(0.0, 1.0);
|
||||
for (int32_t i = 0; i < n_vocab; i++) {
|
||||
double noise = uniform(rng);
|
||||
// Prevent log(0)
|
||||
noise = std::max(noise, 1e-20);
|
||||
double gumbel_noise = std::pow(-std::log(noise), temperature);
|
||||
logits[i] = std::exp(logits[i]) / gumbel_noise;
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
|
||||
std::vector<int32_t> num_transfer_tokens(steps);
|
||||
|
||||
int32_t base = mask_count / steps;
|
||||
int32_t remainder = mask_count % steps;
|
||||
|
||||
for (int32_t i = 0; i < steps; i++) {
|
||||
num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
|
||||
}
|
||||
|
||||
return num_transfer_tokens;
|
||||
}
|
||||
|
||||
void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated) {
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(params.max_length);
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(params.max_length);
|
||||
|
||||
// Setup sampler chain
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(params.max_length, 0, 1);
|
||||
batch.n_tokens = params.max_length;
|
||||
|
||||
// Pre-allocate buffers for CFG if needed
|
||||
int32_t logits_size = n_vocab * params.max_length;
|
||||
std::vector<float> cond_logits_buffer;
|
||||
std::vector<llama_token> un_x_buffer;
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
cond_logits_buffer.resize(logits_size);
|
||||
un_x_buffer.resize(params.max_length);
|
||||
}
|
||||
|
||||
// For block-based processing
|
||||
std::vector<int32_t> num_transfer_tokens;
|
||||
int32_t num_blocks = 1;
|
||||
int32_t steps_per_block = params.steps;
|
||||
|
||||
if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
GGML_ASSERT(params.max_length % params.block_length == 0);
|
||||
num_blocks = params.max_length / params.block_length;
|
||||
GGML_ASSERT(params.steps % num_blocks == 0);
|
||||
steps_per_block = params.steps / num_blocks;
|
||||
}
|
||||
|
||||
std::vector<float> confidence(params.max_length);
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
int64_t time_start = ggml_time_us();
|
||||
|
||||
for (int block_num = 0; block_num < num_blocks; block_num++) {
|
||||
int32_t block_start = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
|
||||
int32_t block_end = (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) ?
|
||||
std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
|
||||
params.max_length;
|
||||
|
||||
// Count masked tokens in current block for block-based processing
|
||||
if (params.schedule == DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED) {
|
||||
int32_t block_mask_count = 0;
|
||||
for (int i = block_start; i < block_end; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
block_mask_count++;
|
||||
}
|
||||
}
|
||||
num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
|
||||
}
|
||||
|
||||
for (int32_t step = 0; step < steps_per_block; step++) {
|
||||
int32_t global_step = block_num * steps_per_block + step;
|
||||
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(
|
||||
global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup batch
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
float * logits = nullptr;
|
||||
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate conditional");
|
||||
break;
|
||||
}
|
||||
float * cond_logits_ptr = llama_get_logits(ctx);
|
||||
std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
|
||||
|
||||
// Unconditional generation (mask input)
|
||||
std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
|
||||
for (int32_t i = 0; i < n_input; i++) {
|
||||
un_x_buffer[i] = params.mask_token_id;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = un_x_buffer[i];
|
||||
}
|
||||
ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate unconditional");
|
||||
break;
|
||||
}
|
||||
float * uncond_logits = llama_get_logits(ctx);
|
||||
|
||||
// Apply CFG
|
||||
for (int32_t i = 0; i < logits_size; i++) {
|
||||
cond_logits_buffer[i] =
|
||||
uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
|
||||
}
|
||||
logits = cond_logits_buffer.data();
|
||||
} else {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
|
||||
break;
|
||||
}
|
||||
logits = llama_get_logits(ctx);
|
||||
}
|
||||
|
||||
if (!logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
if (params.shift_logits) {
|
||||
return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
|
||||
}
|
||||
return logits + pos * n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
// For block-based, only consider current block
|
||||
if (params.schedule != DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED || (i >= block_start && i < block_end)) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (params.add_gumbel_noise && params.temperature > 0.0f) {
|
||||
add_gumbel_noise(logits, n_vocab, params.temperature, rng);
|
||||
}
|
||||
|
||||
if (params.algorithm == DIFFUSION_ALGORITHM_ORIGIN) {
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
float p_transfer = (float) transfer_count / mask_positions.size();
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
(size_t) n_vocab,
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float conf = calculate_confidence(cur_p, params.algorithm, rng);
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(conf, i);
|
||||
}
|
||||
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
|
||||
if (transfer_count > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(),
|
||||
confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
|
||||
confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
for (size_t i = 0; i < confidences.size(); i++) {
|
||||
float conf_logit = confidences[i].first / params.alg_temp;
|
||||
conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
conf_candidates.data(),
|
||||
conf_candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int32_t selected_idx = conf_array.selected;
|
||||
int32_t mask_idx = selected_idx;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0,
|
||||
total_time / 1000.0 / params.steps,
|
||||
total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = params.max_length;
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
enum diffusion_algorithm {
|
||||
DIFFUSION_ALGORITHM_ORIGIN = 0,
|
||||
DIFFUSION_ALGORITHM_ENTROPY_BASED = 1,
|
||||
DIFFUSION_ALGORITHM_MARGIN_BASED = 2,
|
||||
DIFFUSION_ALGORITHM_RANDOM = 3,
|
||||
DIFFUSION_ALGORITHM_CONFIDENCE_BASED = 4,
|
||||
};
|
||||
|
||||
// Unified transfer scheduling methods
|
||||
enum diffusion_transfer_schedule {
|
||||
DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
|
||||
DIFFUSION_TRANSFER_SCHEDULE_BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
|
||||
};
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps = 0;
|
||||
float temperature = 0;
|
||||
llama_token mask_token_id = LLAMA_TOKEN_NULL;
|
||||
diffusion_step_callback_t step_callback = nullptr;
|
||||
void * step_callback_user_data = nullptr;
|
||||
int32_t seed = 0;
|
||||
bool visual_mode = false;
|
||||
bool shift_logits = false; // Shift logits by -1 after decode
|
||||
|
||||
float top_p = 0.;
|
||||
int32_t top_k = 0.;
|
||||
|
||||
diffusion_algorithm algorithm = DIFFUSION_ALGORITHM_CONFIDENCE_BASED;
|
||||
diffusion_transfer_schedule schedule = DIFFUSION_TRANSFER_SCHEDULE_TIMESTEP_BASED;
|
||||
|
||||
float cfg_scale = 0.; // Config scale for classifier-free guidance
|
||||
float eps = 0.; // Timestep scheduling
|
||||
int32_t block_length = 0; // Block size (for block scheduling)
|
||||
float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
|
||||
bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
|
||||
|
||||
int32_t max_length = 0; // Maximum sequence length
|
||||
};
|
||||
|
||||
void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated);
|
||||
@@ -38,8 +38,12 @@ int main(int argc, char ** argv) {
|
||||
std::string result0;
|
||||
std::string result1;
|
||||
std::string result2;
|
||||
std::string result3;
|
||||
|
||||
// init
|
||||
|
||||
ggml_backend_load_all();
|
||||
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
auto * model = llama_init->model();
|
||||
@@ -213,11 +217,83 @@ int main(int argc, char ** argv) {
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
// test on-device state save/load
|
||||
auto params_ctx4 = common_context_params_to_llama(params);
|
||||
params_ctx4.n_seq_max = 2;
|
||||
llama_context * ctx4 = llama_init_from_model(model, params_ctx4);
|
||||
|
||||
llama_sampler * smpl4 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl4, llama_sampler_init_dist(params.sampling.seed));
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
n_token_count_out = 0;
|
||||
|
||||
if (!llama_state_load_file(ctx4, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx4, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
// save kv of seq 0
|
||||
std::vector<uint8_t> seq_store(llama_state_seq_get_size_ext(ctx4, 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE));
|
||||
const size_t ncopy = llama_state_seq_get_data_ext(ctx4, seq_store.data(), seq_store.size(), 0, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (ncopy != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
||||
|
||||
// erase whole kv
|
||||
llama_memory_clear(llama_get_memory(ctx4), true);
|
||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 0
|
||||
const size_t nset = llama_state_seq_set_data_ext(ctx4, seq_store.data(), seq_store.size(), 1, LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
if (nset != seq_store.size()) {
|
||||
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||
return 1;
|
||||
}
|
||||
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
|
||||
}
|
||||
|
||||
// forth run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl4, ctx4, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx4, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result3 += next_token_str;
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, next_token, n_past, {1}, true);
|
||||
|
||||
if (llama_decode(ctx4, batch)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return 1;
|
||||
}
|
||||
n_past += 1;
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_sampler_free(smpl2);
|
||||
llama_sampler_free(smpl3);
|
||||
llama_sampler_free(smpl4);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
@@ -226,12 +302,18 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_free(ctx2);
|
||||
llama_free(ctx3);
|
||||
llama_free(ctx4);
|
||||
|
||||
if (result0 != result2) {
|
||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (result0 != result3) {
|
||||
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n%s : success\n", __func__);
|
||||
|
||||
return 0;
|
||||
|
||||
+2
-2
@@ -4,8 +4,8 @@ project("ggml" C CXX ASM)
|
||||
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 10)
|
||||
set(GGML_VERSION_PATCH 1)
|
||||
set(GGML_VERSION_MINOR 11)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
@@ -438,6 +438,12 @@ extern "C" {
|
||||
GGML_PREC_F32 = 10,
|
||||
};
|
||||
|
||||
// op hint
|
||||
enum ggml_op_hint {
|
||||
GGML_HINT_NONE = 0,
|
||||
GGML_HINT_SRC0_IS_HADAMARD = 1,
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum ggml_ftype {
|
||||
GGML_FTYPE_UNKNOWN = -1,
|
||||
@@ -1419,6 +1425,11 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_prec prec);
|
||||
|
||||
// change the hint of a matrix multiplication
|
||||
GGML_API void ggml_mul_mat_set_hint(
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op_hint hint);
|
||||
|
||||
// indirect matrix multiplication
|
||||
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
@@ -578,13 +578,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.22.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.24.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/releases/download/${KLEIDIAI_COMMIT_TAG}/kleidiai-${KLEIDIAI_COMMIT_TAG}-src.tar.gz")
|
||||
set(KLEIDIAI_RELEASE_ARCHIVE_MD5 "2f02ebe29573d45813e671eb304f2a00")
|
||||
|
||||
set(KLEIDIAI_FETCH_ARGS
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}
|
||||
URL_HASH MD5=${KLEIDIAI_RELEASE_ARCHIVE_MD5}
|
||||
)
|
||||
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
|
||||
list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW)
|
||||
|
||||
@@ -1245,6 +1245,12 @@ void ggml_compute_forward_mul_mat(
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const int32_t hint = ggml_get_op_params_i32(dst, 1);
|
||||
if (hint == GGML_HINT_SRC0_IS_HADAMARD && !params->use_ref) {
|
||||
ggml_compute_forward_fwht(params, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
|
||||
@@ -11212,3 +11212,91 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t n = ne10;
|
||||
GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2
|
||||
|
||||
const int64_t nr = ne11 * ne12 * ne13;
|
||||
const int64_t rows_per_thread = (nr + nth - 1) / nth;
|
||||
const int64_t start_row = ith * rows_per_thread;
|
||||
const int64_t end_row = MIN(start_row + rows_per_thread, nr);
|
||||
|
||||
const float scale = 1.0f / sqrtf((float)n);
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f);
|
||||
#endif
|
||||
|
||||
for (int64_t r = start_row; r < end_row; r++) {
|
||||
const int64_t i13 = r / (ne11 * ne12);
|
||||
const int64_t i12 = (r - i13 * ne11 * ne12) / ne11;
|
||||
const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11;
|
||||
|
||||
const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13);
|
||||
float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3);
|
||||
|
||||
for (int64_t j = 0; j < n; j++) {
|
||||
dst_row[j] = src_row[j] * scale;
|
||||
}
|
||||
|
||||
// Scalar passes
|
||||
#if defined(GGML_SIMD)
|
||||
const int step = GGML_F32_EPR;
|
||||
#else
|
||||
const int step = n;
|
||||
#endif
|
||||
for (int64_t len = 1; len < step && len < n; len <<= 1) {
|
||||
for (int64_t i = 0; i < n; i += 2 * len) {
|
||||
for (int64_t j = 0; j < len; j++) {
|
||||
float u = dst_row[i + j];
|
||||
float v = dst_row[i + len + j];
|
||||
dst_row[i + j] = u + v;
|
||||
dst_row[i + len + j] = u - v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SIMD passes using GGML_F32_VEC_* macros for multi-architecture support
|
||||
#if defined(GGML_SIMD)
|
||||
for (int64_t len = step; len < n; len <<= 1) {
|
||||
for (int64_t i = 0; i < n; i += 2 * len) {
|
||||
for (int64_t j = 0; j < len; j += step) {
|
||||
GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j);
|
||||
GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j);
|
||||
|
||||
GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v));
|
||||
GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_fwht_f32(params, dst);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error - fwht is F32 only");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
|
||||
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_fwht(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -6,17 +6,18 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void k_get_rows(
|
||||
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
|
||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
|
||||
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
|
||||
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = z / ne12; // TODO fastdiv
|
||||
const int i12 = z % ne12;
|
||||
const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
|
||||
const int i11 = dm.x;
|
||||
const int i12 = dm.y;
|
||||
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
@@ -42,17 +43,18 @@ template<typename src0_t, typename dst_t>
|
||||
static __global__ void k_get_rows_float(
|
||||
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
|
||||
/*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
|
||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
|
||||
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
|
||||
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = z / ne12; // TODO fastdiv
|
||||
const int i12 = z % ne12;
|
||||
const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
|
||||
const int i11 = dm.x;
|
||||
const int i12 = dm.y;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
@@ -115,10 +117,14 @@ static void get_rows_cuda_q(
|
||||
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
|
||||
GGML_ASSERT(ne12 > 0);
|
||||
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
|
||||
const uint3 ne12_fdv = init_fastdiv_values(ne12);
|
||||
|
||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
@@ -146,10 +152,14 @@ static void get_rows_cuda_float(
|
||||
const size_t s12 = nb12 / sizeof(int32_t);
|
||||
// const size_t s13 = nb13 / sizeof(int32_t);
|
||||
|
||||
GGML_ASSERT(ne12 > 0);
|
||||
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
|
||||
const uint3 ne12_fdv = init_fastdiv_values(ne12);
|
||||
|
||||
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
|
||||
@@ -5431,8 +5431,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
char pci_bus_id[16] = {};
|
||||
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
|
||||
char pci_bus_id[32] = {};
|
||||
CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
|
||||
dev_ctx->pci_bus_id = pci_bus_id;
|
||||
dev_ctx->op_offload_min_batch_size = min_batch_size;
|
||||
|
||||
|
||||
Vendored
+1
@@ -55,6 +55,7 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
|
||||
Vendored
+1
@@ -39,6 +39,7 @@
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
|
||||
@@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
|
||||
@@ -2254,8 +2254,7 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->ne[2] != 1 || dst->ne[3] != 1) {
|
||||
// FA during prompt still needs work
|
||||
if (dst->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2421,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
|
||||
// TODO: add support for non-contiguous elements within a row
|
||||
if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
@@ -52,11 +57,13 @@ if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
|
||||
@@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
|
||||
|
||||
set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2")
|
||||
|
||||
set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
|
||||
set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
|
||||
|
||||
@@ -17,13 +17,14 @@
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
|
||||
// This is a bit of a hack because the compiler is strugling to properly inline
|
||||
// the default hvx_vec_f32_to_f16 with output into the local array.
|
||||
static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
{
|
||||
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
|
||||
}
|
||||
@@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
if (ret == HTP_STATUS_OK) {
|
||||
return ret;
|
||||
}
|
||||
// VTCM too small or other failure -> fall through to HVX path
|
||||
}
|
||||
#endif
|
||||
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
|
||||
@@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx,
|
||||
int m, int k, int n,
|
||||
int weight_type);
|
||||
|
||||
// HMX flash attention
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
#ifndef HMX_UTILS_H
|
||||
#define HMX_UTILS_H
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <stddef.h>
|
||||
|
||||
@@ -12,21 +15,188 @@
|
||||
#define HMX_FP16_TILE_N_ELMS 1024
|
||||
#define HMX_FP16_TILE_SIZE 2048
|
||||
|
||||
#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))
|
||||
|
||||
// Initialise aligned 256-byte area with scale vector + zero padding.
|
||||
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
HVX_Vector *pv = (HVX_Vector *)out_scales;
|
||||
*pv++ = v_scale;
|
||||
*pv = Q6_V_vzero();
|
||||
static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
|
||||
volatile HVX_Vector *pv = (HVX_Vector *) out_scales;
|
||||
pv[0] = v_scale;
|
||||
pv[1] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---
|
||||
// --- Shared scatter offsets and interleave helper ---
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
uint8_t *p = *vtcm_ptr;
|
||||
*vtcm_ptr += size;
|
||||
return p;
|
||||
// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile.
|
||||
// word[i] = i*128 maps K-row-pair i to byte offset i*128.
|
||||
// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047);
|
||||
// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the
|
||||
// call site to scatter into one tile (masked) or two contiguous tiles (unmasked).
|
||||
static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
|
||||
0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128,
|
||||
11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128,
|
||||
22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128,
|
||||
};
|
||||
|
||||
// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles.
|
||||
// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used)
|
||||
// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16
|
||||
// Processes rows [start_row, end_row) for multi-thread slicing.
|
||||
// Full range: start_row=0, end_row=n_cols.
|
||||
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
const __fp16 * restrict vtcm_src,
|
||||
int n_cols,
|
||||
int k,
|
||||
int src_stride,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
assert(k % HMX_FP16_TILE_N_COLS == 0);
|
||||
|
||||
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
// Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles.
|
||||
// When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask)
|
||||
// using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when
|
||||
// n_k_tiles is odd) falls back to single-tile masked scatter.
|
||||
const bool pair_scatter = (n_k_tiles & 1) == 0;
|
||||
const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1);
|
||||
const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1);
|
||||
__builtin_assume(k > 0);
|
||||
__builtin_assume(end_row > start_row);
|
||||
|
||||
if (pair_scatter) {
|
||||
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
|
||||
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
|
||||
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1);
|
||||
p1 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: scatter one K-tile per call (region 2047, masked).
|
||||
const int c_step = HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
|
||||
__fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1);
|
||||
p1 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0);
|
||||
p0 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
|
||||
tile_base += dst_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Interleave row-major FP16 data into column-major tile format.
|
||||
// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile].
|
||||
// Processes rows [start_row, end_row) for multi-thread slicing.
|
||||
// Full range: start_row=0, end_row=n_rows.
|
||||
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
const __fp16 * restrict src,
|
||||
int n_rows,
|
||||
int head_dim,
|
||||
int src_stride,
|
||||
int n_row_tiles,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
__builtin_assume(head_dim > 0);
|
||||
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
|
||||
|
||||
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
|
||||
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
// Row-pair invariants hoisted out of the c loop.
|
||||
const int r0 = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
|
||||
|
||||
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
|
||||
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
|
||||
__fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS;
|
||||
__fp16 * tb1 = tb0 + tile_stride_elms;
|
||||
const size_t tb_step = 2 * tile_stride_elms;
|
||||
|
||||
if (pv_in1) {
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_Vector v1 = *pv_in1++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
|
||||
tb0 += tb_step;
|
||||
tb1 += tb_step;
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp);
|
||||
tb0 += tb_step;
|
||||
tb1 += tb_step;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // HMX_UTILS_H
|
||||
|
||||
@@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline _Float16 hvx_vec_get_f16(HVX_Vector v) {
|
||||
_Float16 __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 2, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
|
||||
#include "hvx-base.h"
|
||||
|
||||
#define hvx_splat_loop_body(dst_type, vec_store) \
|
||||
#define hvx_splat_pragma(x) _Pragma(#x)
|
||||
#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
\
|
||||
@@ -16,7 +17,7 @@
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
hvx_splat_pragma(unroll(unroll_cnt)) \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = src; \
|
||||
} \
|
||||
@@ -25,31 +26,47 @@
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
|
||||
hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
|
||||
static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
|
||||
hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
|
||||
static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) {
|
||||
static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) {
|
||||
static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) {
|
||||
hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) {
|
||||
hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) {
|
||||
hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t));
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) {
|
||||
hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1);
|
||||
}
|
||||
|
||||
static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) {
|
||||
hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1);
|
||||
}
|
||||
|
||||
#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x42B16666) // 88.7
|
||||
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
|
||||
@@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
static const float kInf = INFINITY;
|
||||
static const float kMaxExp = 88.7f;
|
||||
static const float kMaxExp = 88.7228f;
|
||||
|
||||
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
|
||||
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
|
||||
|
||||
@@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
// M alignment: when M > 32 but not 32-aligned, we split into
|
||||
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
|
||||
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
|
||||
// M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows)
|
||||
// is handled by HMX itself; when M < 32 fall back to HVX.
|
||||
const int m_total = (int) src1->ne[1];
|
||||
const int m_tail = m_total % 32;
|
||||
const int m_hmx = m_total - m_tail;
|
||||
const int m_hmx = m_total & ~31; // 0 when M < 32
|
||||
|
||||
if (m_hmx == 0) {
|
||||
return op_matmul_hvx(octx);
|
||||
@@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
int k = (int) src0->ne[0]; // inner dimension
|
||||
int n = (int) src0->ne[1]; // weight columns
|
||||
|
||||
// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
|
||||
int ret = -1;
|
||||
|
||||
// Row strides in elements. For compact tensors these equal k; for
|
||||
@@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
.dst = (float *) dst->data,
|
||||
.activation = (float *) src1->data,
|
||||
.permuted_weight = (const __fp16 *) src0->data,
|
||||
.m = m_hmx,
|
||||
.m = m_total,
|
||||
.k = k,
|
||||
.n = n,
|
||||
.act_stride = act_stride,
|
||||
@@ -3048,12 +3045,12 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
|
||||
m_hmx, k, n, act_stride, wgt_stride);
|
||||
m_total, k, n, act_stride, wgt_stride);
|
||||
}
|
||||
} else {
|
||||
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
|
||||
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
|
||||
m_hmx, k, n, (int) src0->type);
|
||||
m_total, k, n, (int) src0->type);
|
||||
}
|
||||
|
||||
if (ret != 0) {
|
||||
@@ -3061,27 +3058,6 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
return op_matmul(octx);
|
||||
}
|
||||
|
||||
// --- Phase 2: HVX on the remaining m_tail rows ---
|
||||
if (m_tail > 0) {
|
||||
// copy of src1 and dst
|
||||
struct htp_tensor src1_tail = *src1;
|
||||
struct htp_tensor dst_tail = *dst;
|
||||
|
||||
src1_tail.ne[1] = m_tail; // only tail rows
|
||||
dst_tail.ne[1] = m_tail; // only tail rows
|
||||
|
||||
// Offset activation and dst pointers past the HMX-processed rows.
|
||||
// Use nb[1] (row stride in bytes) to compute the byte offset.
|
||||
src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
|
||||
dst_tail.data += (uint32_t) m_hmx * dst->nb[1];
|
||||
|
||||
octx->src[1] = &src1_tail;
|
||||
octx->dst = &dst_tail;
|
||||
|
||||
FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
|
||||
return op_matmul_hvx(octx);
|
||||
}
|
||||
|
||||
return 0;
|
||||
#endif // HTP_HAS_HMX
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ struct htp_unary_context {
|
||||
const uint8_t * data_src0;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src0_data_row_size; // actual data bytes per row
|
||||
size_t dst_data_row_size; // actual data bytes per row
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
@@ -41,6 +41,40 @@ struct htp_unary_context {
|
||||
uint32_t nc;
|
||||
};
|
||||
|
||||
// Convert flat row index to DDR byte offset using the tensor's actual strides.
|
||||
// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3
|
||||
static inline size_t unary_row_offset(uint32_t ir,
|
||||
uint32_t ne1, uint32_t ne2,
|
||||
size_t nb1, size_t nb2, size_t nb3) {
|
||||
const uint32_t i1 = ir % ne1;
|
||||
const uint32_t i2 = (ir / ne1) % ne2;
|
||||
const uint32_t i3 = ir / (ne1 * ne2);
|
||||
return i1 * nb1 + i2 * nb2 + i3 * nb3;
|
||||
}
|
||||
// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice
|
||||
// boundary of src and dst so the nb1 stride stays valid for all rows.
|
||||
static inline uint32_t unary_block_size(uint32_t ir,
|
||||
uint32_t end_row,
|
||||
uint32_t block,
|
||||
bool src_contig,
|
||||
bool dst_contig,
|
||||
uint32_t src_ne1,
|
||||
uint32_t dst_ne1) {
|
||||
uint32_t limit = MIN(block, end_row - ir);
|
||||
|
||||
if (!src_contig) {
|
||||
const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1;
|
||||
limit = MIN(limit, src_slice_end - ir);
|
||||
}
|
||||
|
||||
if (!dst_contig) {
|
||||
const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1;
|
||||
limit = MIN(limit, dst_slice_end - ir);
|
||||
}
|
||||
|
||||
return limit;
|
||||
}
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
@@ -276,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
int32_t * op_params = octx->op_params;
|
||||
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
||||
|
||||
const size_t src0_row_size = uctx->src0_row_size;
|
||||
const size_t dst_row_size = uctx->dst_row_size;
|
||||
const size_t src0_data_row_size = uctx->src0_data_row_size;
|
||||
const size_t dst_data_row_size = uctx->dst_data_row_size;
|
||||
|
||||
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
||||
@@ -303,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = uctx->block;
|
||||
// Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride
|
||||
// 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every
|
||||
// transfer stays within a nb1-uniform region. Skipped for contiguous tensors.
|
||||
const bool src0_contig = (nb02 == (size_t)ne01 * nb01) &&
|
||||
(nb03 == (size_t)ne02 * nb02);
|
||||
const bool dst_contig = (nb2 == (size_t)ne1 * nb1) &&
|
||||
(nb3 == (size_t)ne2 * nb2);
|
||||
const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01);
|
||||
const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1);
|
||||
const uint32_t BLOCK = MIN(src0_max_block, dst_max_block);
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
@@ -312,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) {
|
||||
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
nb1, dst_row_size_aligned, dst_data_row_size, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off),
|
||||
src0_row_size_aligned, nb01, src0_data_row_size, block_size);
|
||||
ir += block_size;
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ) {
|
||||
const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
@@ -361,18 +406,25 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
|
||||
break;
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
|
||||
dst_row_size, dst_row_size_aligned, block_size);
|
||||
const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(data_dst + dst_off, dst_spad),
|
||||
nb1, dst_row_size_aligned, dst_data_row_size, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
const uint32_t next_ir = ir + block_size;
|
||||
if (next_ir < src0_end_row) {
|
||||
const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
const uint32_t pref_ir = next_ir + next_block_size;
|
||||
if (pref_ir < src0_end_row) {
|
||||
const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1);
|
||||
const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03);
|
||||
dma_queue_push(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + src0_pref_off),
|
||||
src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
ir += block_size;
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
@@ -426,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
const size_t src0_data_row_size = src0->ne[0] * sizeof(float);
|
||||
const size_t dst_data_row_size = dst->ne[0] * sizeof(float);
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN);
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
@@ -468,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
.data_src0 = (const uint8_t *)src0->data,
|
||||
.data_dst = (uint8_t *)dst->data,
|
||||
|
||||
.src0_row_size = src0_row_size,
|
||||
.dst_row_size = dst_row_size,
|
||||
.src0_data_row_size = src0_data_row_size,
|
||||
.dst_data_row_size = dst_data_row_size,
|
||||
|
||||
.src0_row_size_aligned = src0_row_size_aligned,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
#ifndef VTCM_UTILS_H
|
||||
#define VTCM_UTILS_H
|
||||
|
||||
#include "hex-utils.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <hexagon_types.h>
|
||||
|
||||
static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
|
||||
uint8_t *p = *vtcm_ptr;
|
||||
*vtcm_ptr += size;
|
||||
return p;
|
||||
}
|
||||
|
||||
#endif // VTCM_UTILS_H
|
||||
@@ -282,6 +282,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
|
||||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
bool ggml_metal_buffer_cpy_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||
void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value);
|
||||
|
||||
// finds the Metal buffer that contains the tensor data on the GPU device
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#import "ggml-metal-device.h"
|
||||
|
||||
#import "ggml-impl.h"
|
||||
#import "ggml-backend-impl.h"
|
||||
|
||||
#include <Foundation/Foundation.h>
|
||||
|
||||
@@ -1737,6 +1738,47 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_metal_buffer_cpy_tensor(ggml_metal_buffer_t buf_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||
ggml_metal_buffer_t buf_src = (ggml_metal_buffer_t)src->buffer->context;
|
||||
|
||||
const size_t size = ggml_nbytes(src);
|
||||
|
||||
// if both buffers are shared, we can use memcpy directly
|
||||
if (buf_dst->is_shared && buf_src->is_shared) {
|
||||
memcpy(dst->data, src->data, size);
|
||||
return true;
|
||||
}
|
||||
|
||||
// for private buffers, we need to use Metal blit commands
|
||||
@autoreleasepool {
|
||||
struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf_src, src);
|
||||
struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf_dst, dst);
|
||||
|
||||
if (bid_src.metal == nil || bid_dst.metal == nil) {
|
||||
return false;
|
||||
}
|
||||
|
||||
id<MTLCommandBuffer> cmd_buf = [buf_dst->dev->mtl_queue commandBufferWithUnretainedReferences];
|
||||
|
||||
{
|
||||
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
|
||||
|
||||
[encoder copyFromBuffer:bid_src.metal
|
||||
sourceOffset:bid_src.offs
|
||||
toBuffer:bid_dst.metal
|
||||
destinationOffset:bid_dst.offs
|
||||
size:size];
|
||||
|
||||
[encoder endEncoding];
|
||||
}
|
||||
|
||||
[cmd_buf commit];
|
||||
[cmd_buf waitUntilCompleted];
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
|
||||
if (buf->is_shared) {
|
||||
memset(buf->all_data, value, buf->all_size);
|
||||
|
||||
@@ -17,6 +17,9 @@
|
||||
// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices
|
||||
static int g_devices = 1;
|
||||
|
||||
// forward declaration
|
||||
static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// backend interface
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -68,11 +71,11 @@ static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t bu
|
||||
|
||||
GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
if (!ggml_backend_buffer_is_metal(src->buffer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
return ggml_metal_buffer_cpy_tensor(ctx, src, dst);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
@@ -144,11 +147,11 @@ static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t b
|
||||
|
||||
GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
if (!ggml_backend_buffer_is_metal(src->buffer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
return ggml_metal_buffer_cpy_tensor(ctx, src, dst);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
|
||||
@@ -66,8 +66,6 @@ set(GGML_OPENCL_KERNELS
|
||||
diag
|
||||
div
|
||||
gelu
|
||||
gemv_noshuffle_general
|
||||
gemv_noshuffle
|
||||
get_rows
|
||||
glu
|
||||
group_norm
|
||||
@@ -75,7 +73,6 @@ set(GGML_OPENCL_KERNELS
|
||||
im2col_f32
|
||||
im2col_f16
|
||||
mean
|
||||
mul_mat_Ab_Bi_8x4
|
||||
mul_mv_f16_f16
|
||||
mul_mv_f16_f32_1row
|
||||
mul_mv_f16_f32_l4
|
||||
@@ -107,6 +104,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_id_mxfp4_f32_flat
|
||||
gemm_moe_mxfp4_f32
|
||||
gemv_moe_mxfp4_f32
|
||||
gemm_moe_mxfp4_f32_ns
|
||||
gemv_moe_mxfp4_f32_ns
|
||||
moe_reorder_b
|
||||
moe_sort_by_expert
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
@@ -116,12 +117,15 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_q4_k_f32_l4_lm
|
||||
mul_mm_q5_k_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_q4_0_f32
|
||||
gemv_noshuffle_q4_0_f32_spec
|
||||
gemm_noshuffle_q4_0_f32
|
||||
gemv_noshuffle_q4_1_f32
|
||||
gemm_noshuffle_q4_1_f32
|
||||
gemv_noshuffle_iq4_nl_f32
|
||||
gemm_noshuffle_iq4_nl_f32
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
gemv_noshuffle_q8_0_f32
|
||||
gemm_noshuffle_q8_0_f32
|
||||
gemv_noshuffle_q4_k_f32
|
||||
gemm_noshuffle_q4_k_f32
|
||||
gemv_noshuffle_q6_k_f32
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -371,6 +371,93 @@ kernel void kernel_restore_block_mxfp4_trans(
|
||||
b->e = src_e[src_blk_offset];
|
||||
}
|
||||
|
||||
kernel void kernel_convert_block_mxfp4_trans4_ns(
|
||||
global struct block_mxfp4 * src0,
|
||||
__global uint * dst_q,
|
||||
__global uchar * dst_e,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_MXFP4;
|
||||
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
global struct block_mxfp4 * b = src0 + src_blk_offset;
|
||||
dst_e[dst_blk_offset] = b->e;
|
||||
|
||||
// extract quantization and unshuffle
|
||||
ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0];
|
||||
|
||||
ushort8 post_block = (ushort8)(0);
|
||||
|
||||
uchar * pre_block_ptr = (uchar *)(&pre_block);
|
||||
uchar * post_block_ptr = (uchar *)(&post_block);
|
||||
|
||||
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
|
||||
uchar x0 = pre_block_ptr[2*i + 0];
|
||||
uchar x1 = pre_block_ptr[2*i + 1];
|
||||
|
||||
post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
|
||||
post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
|
||||
}
|
||||
|
||||
uint4 q_block = as_uint4(post_block);
|
||||
|
||||
uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
|
||||
dst_q[offset] = q_block.x;
|
||||
dst_q[offset + ne01] = q_block.y;
|
||||
dst_q[offset + ne01 * 2] = q_block.z;
|
||||
dst_q[offset + ne01 * 3] = q_block.w;
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_mxfp4_trans4_ns(
|
||||
__global uint * src_q,
|
||||
__global uchar * src_e,
|
||||
__global struct block_mxfp4 * dst0,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_MXFP4;
|
||||
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
__global struct block_mxfp4 * b = dst0 + dst_blk_offset;
|
||||
b->e = src_e[src_d_offset];
|
||||
|
||||
// collect transposed quantization parts for a block
|
||||
uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
|
||||
uint4 q_block;
|
||||
q_block.x = src_q[src_q_offset];
|
||||
q_block.y = src_q[src_q_offset + ne01];
|
||||
q_block.z = src_q[src_q_offset + ne01 * 2];
|
||||
q_block.w = src_q[src_q_offset + ne01 * 3];
|
||||
|
||||
ushort8 post_block = as_ushort8(q_block);
|
||||
ushort8 pre_block = (ushort8)(0);
|
||||
|
||||
uchar * pre_block_ptr = (uchar *)(&pre_block);
|
||||
uchar * post_block_ptr = (uchar *)(&post_block);
|
||||
|
||||
for (int i = 0; i < QK_MXFP4 / 4; ++i) {
|
||||
uchar x0 = post_block_ptr[i + 0];
|
||||
uchar x1 = post_block_ptr[i + QK_MXFP4 / 4];
|
||||
|
||||
pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
|
||||
pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
|
||||
}
|
||||
|
||||
((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
|
||||
}
|
||||
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q8_0
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
|
||||
|
||||
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
|
||||
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
|
||||
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
|
||||
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
|
||||
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
|
||||
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
|
||||
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
|
||||
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
|
||||
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s0 & 0x8000;
|
||||
|
||||
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
|
||||
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
|
||||
|
||||
ushort2 fp16_packed_a_1, fp16_packed_b_1;
|
||||
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
|
||||
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
|
||||
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
|
||||
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
|
||||
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
|
||||
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
|
||||
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s1 & 0x8000;
|
||||
|
||||
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
|
||||
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
|
||||
|
||||
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
|
||||
}
|
||||
|
||||
|
||||
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
|
||||
acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \
|
||||
acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \
|
||||
acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \
|
||||
acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \
|
||||
acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \
|
||||
acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \
|
||||
acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \
|
||||
acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \
|
||||
acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \
|
||||
acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \
|
||||
acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
|
||||
acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
|
||||
acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
|
||||
acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
|
||||
acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
|
||||
acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
|
||||
acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
|
||||
acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
|
||||
acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
|
||||
acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
|
||||
acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
|
||||
acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
|
||||
acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
|
||||
acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
|
||||
acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
|
||||
acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
|
||||
acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
|
||||
acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
|
||||
acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
|
||||
acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
|
||||
acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
|
||||
acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
|
||||
acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
|
||||
acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
|
||||
acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
|
||||
acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
|
||||
acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
|
||||
acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
|
||||
acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
|
||||
acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
|
||||
acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
|
||||
acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
|
||||
acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
|
||||
acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
|
||||
acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
|
||||
acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
|
||||
acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
|
||||
acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \
|
||||
acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \
|
||||
acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \
|
||||
acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \
|
||||
acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
|
||||
acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
|
||||
acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
|
||||
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
|
||||
acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
|
||||
acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
|
||||
acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
|
||||
acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
|
||||
acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
|
||||
acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
|
||||
acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
|
||||
acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
|
||||
c_reg.lo += convert_float8(acc.lo); \
|
||||
c_reg.hi += convert_float8(acc.hi); \
|
||||
|
||||
|
||||
static inline half e8m0_to_fp16(uchar x) {
|
||||
ushort bits;
|
||||
bits = (ushort)(x) - (ushort)(112);
|
||||
bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10);
|
||||
return as_half(bits);
|
||||
}
|
||||
|
||||
static inline float e8m0_to_fp32(uchar x) {
|
||||
int bits;
|
||||
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
|
||||
return as_float(bits);
|
||||
}
|
||||
|
||||
|
||||
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
|
||||
kernel void kernel_gemm_moe_mxfp4_f32_ns(
|
||||
__read_only image1d_buffer_t src0_q,
|
||||
__global uchar * src0_d,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global ushort * src2_emap,
|
||||
__write_only image1d_buffer_t dst,
|
||||
__global int * total_tiles,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
uint block_id_m = get_global_id(1); // m_tile
|
||||
uint block_id_n = get_global_id(2); // n_tile
|
||||
|
||||
// Boundary check
|
||||
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
|
||||
return;
|
||||
}
|
||||
|
||||
__private half16 reg_a;
|
||||
__private float32 reg_c = (float32)(0);
|
||||
__local half4 shared_b[128];
|
||||
|
||||
const ushort expert_id = src2_emap[block_id_n];
|
||||
|
||||
const uint row = block_id_m * TILESIZE_M;
|
||||
const uint col = block_id_n * TILESIZE_N;
|
||||
|
||||
uint sub_block_id_m = get_local_id(0);
|
||||
uint2 b_global_offset;
|
||||
b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
|
||||
b_global_offset.y = b_global_offset.x + (16 * ne00);
|
||||
uint2 b_local_offset;
|
||||
b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
|
||||
b_local_offset.y = b_local_offset.x + 16;
|
||||
|
||||
// Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks
|
||||
for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
|
||||
// First sub-block
|
||||
uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
|
||||
uint b_sub_offset = col * ne00 + step;
|
||||
|
||||
// Load scale for current mxfp4 block
|
||||
uint s_offset = s_sub_offset + get_global_id(0);
|
||||
float s = e8m0_to_fp32(src0_d[s_offset]);
|
||||
|
||||
// Load 16 fp4 (64-bits) in transposed layout
|
||||
uint2 mxfp4x16;
|
||||
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
|
||||
float8 bx8_f32;
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
// Convert to half and store to LM to share within the subgroup
|
||||
half8 bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
// Dequantization
|
||||
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
|
||||
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// 32 16x16 fp16 dot product with 8 elements reduction for better precision
|
||||
half16 acc;
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
|
||||
// Repeat for second sub-block
|
||||
uint half_step = step + TILESIZE_K;
|
||||
q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
|
||||
b_sub_offset = col * ne00 + half_step;
|
||||
|
||||
// Load next 16 fp4 (64-bits) in transposed layout
|
||||
mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
|
||||
mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;
|
||||
|
||||
// Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
|
||||
bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
|
||||
bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
|
||||
// Convert to half and store to LM to share within the subgroup
|
||||
bx8_f16 = convert_half8(bx8_f32);
|
||||
shared_b[b_local_offset.x] = bx8_f16.lo;
|
||||
shared_b[b_local_offset.y] = bx8_f16.hi;
|
||||
|
||||
// Dequantization
|
||||
reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
|
||||
reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;
|
||||
|
||||
sub_group_barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// 32 16x16 fp16 dot product with 3-levels reduction for better precision
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
|
||||
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
|
||||
}
|
||||
|
||||
// Load poster router and share in LM
|
||||
__local uint out_idx[TILESIZE_N];
|
||||
|
||||
if (get_local_id(0) < TILESIZE_N) {
|
||||
uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
|
||||
if (idx == 0xFFFFFFFF) {
|
||||
idx = src2[block_id_n * TILESIZE_N + 0];
|
||||
}
|
||||
out_idx[get_local_id(0)] = idx * ne01;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Scatter results back to original position in output grid
|
||||
uint m_offset = row + get_local_id(0);
|
||||
|
||||
write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1));
|
||||
write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2));
|
||||
write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3));
|
||||
write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4));
|
||||
write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5));
|
||||
write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6));
|
||||
write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7));
|
||||
write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8));
|
||||
write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9));
|
||||
write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
|
||||
write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
|
||||
write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
|
||||
write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
|
||||
write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
|
||||
write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
|
||||
write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
|
||||
write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
|
||||
write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
|
||||
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
|
||||
write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
|
||||
write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
|
||||
write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
|
||||
write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
|
||||
write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
|
||||
write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
|
||||
write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
|
||||
write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
|
||||
write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
|
||||
write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
|
||||
write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
|
||||
write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));
|
||||
|
||||
// Store zero padding parts to the index of first output in tile, override correct result in the end
|
||||
barrier(CLK_GLOBAL_MEM_FENCE);
|
||||
write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
|
||||
}
|
||||
+1
-1
@@ -17,7 +17,7 @@
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mat_Ab_Bi_8x4(
|
||||
kernel void kernel_gemm_noshuffle_q4_0_f32(
|
||||
global const ushort * src0_q, // quantized A
|
||||
global const half * src0_d, // A scales
|
||||
__read_only image1d_buffer_t src1, // B (1d image)
|
||||
+1
-1
@@ -11,7 +11,7 @@
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mm_q8_0_f32_8x4(
|
||||
kernel void kernel_gemm_noshuffle_q8_0_f32(
|
||||
global const uint * src0_q,
|
||||
global const half * src0_d,
|
||||
__read_only image1d_buffer_t src1,
|
||||
@@ -0,0 +1,161 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#define QK_MXFP4 32
|
||||
#define N_SIMDGROUP 4
|
||||
#define SIMDGROUP_WIDTH 64
|
||||
|
||||
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
|
||||
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
|
||||
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
|
||||
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
|
||||
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
|
||||
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
|
||||
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
|
||||
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
|
||||
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s0 & 0x8000;
|
||||
|
||||
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
|
||||
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
|
||||
|
||||
ushort2 fp16_packed_a_1, fp16_packed_b_1;
|
||||
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
|
||||
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
|
||||
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
|
||||
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
|
||||
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
|
||||
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
|
||||
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s1 & 0x8000;
|
||||
|
||||
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
|
||||
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
|
||||
|
||||
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
|
||||
}
|
||||
|
||||
static inline float e8m0_to_fp32(uchar x) {
|
||||
int bits;
|
||||
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
|
||||
return as_float(bits);
|
||||
}
|
||||
|
||||
|
||||
__attribute__((qcom_reqd_sub_group_size("half")))
|
||||
__kernel void kernel_gemv_moe_mxfp4_f32_ns(
|
||||
__global uint * src0_q,
|
||||
__global uchar * src0_e,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne11
|
||||
) {
|
||||
uint i01 = get_global_id(0);
|
||||
uint i20 = get_global_id(2);
|
||||
uint sgid = get_local_id(1);
|
||||
uint slid = get_sub_group_local_id();
|
||||
|
||||
uint i11 = i20 % ne11;
|
||||
|
||||
uint expert_id = src2[i20];
|
||||
uint expert_offset = expert_id * ne00 * ne01 / 32;
|
||||
|
||||
__private float sum = 0.0f; // each thread calculate partial sum of one output
|
||||
|
||||
// loop along ne00 in block granularity, skip 4 blocks every iter
|
||||
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
|
||||
|
||||
// load one block of q
|
||||
uint4 regQ;
|
||||
uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01;
|
||||
|
||||
regQ.s0 = src0_q[block_offset];
|
||||
regQ.s1 = src0_q[block_offset + ne01];
|
||||
regQ.s2 = src0_q[block_offset + ne01 * 2];
|
||||
regQ.s3 = src0_q[block_offset + ne01 * 3];
|
||||
|
||||
uint offset = i11 * ne00 / 4 + ib00 * 8;
|
||||
|
||||
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
|
||||
|
||||
float4 shared_y4;
|
||||
shared_y4 = read_imagef(src1, (offset + 0));
|
||||
float4 acc = shared_y4 * convert_float4(fp16x8.lo);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 1));
|
||||
acc += shared_y4 * convert_float4(fp16x8.hi);
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 2));
|
||||
acc += shared_y4 * convert_float4(fp16x8.lo);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 3));
|
||||
acc += shared_y4 * convert_float4(fp16x8.hi);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 4));
|
||||
acc += shared_y4 * convert_float4(fp16x8.lo);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 5));
|
||||
acc += shared_y4 * convert_float4(fp16x8.hi);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 6));
|
||||
acc += shared_y4 * convert_float4(fp16x8.lo);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 7));
|
||||
acc += shared_y4 * convert_float4(fp16x8.hi);
|
||||
|
||||
uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
|
||||
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #subgroups=4
|
||||
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
|
||||
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
|
||||
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
|
||||
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
// 1 outputs per thread in subgroup 0
|
||||
if (sgid == 0) {
|
||||
dst = dst + (offsetd >> 2);
|
||||
dst[i01 + i20 * ne01] = sum;
|
||||
}
|
||||
|
||||
}
|
||||
+5
-5
@@ -191,7 +191,7 @@
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
__kernel void kernel_gemv_noshuffle(
|
||||
__kernel void kernel_gemv_noshuffle_q4_0_f32(
|
||||
__read_only image1d_buffer_t src0_q, // quantized A
|
||||
global half2 * src0_d, // A scales
|
||||
__read_only image1d_buffer_t src1, // B
|
||||
@@ -238,21 +238,21 @@ __kernel void kernel_gemv_noshuffle(
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAT
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAT
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
|
||||
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAT
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAT
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
+5
-5
@@ -191,7 +191,7 @@
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
__kernel void kernel_gemv_noshuffle(
|
||||
__kernel void kernel_gemv_noshuffle_q4_0_f32(
|
||||
__read_only image1d_buffer_t src0_q, // quantized A
|
||||
global half2 * src0_d, // A scales
|
||||
__read_only image1d_buffer_t src1, // B
|
||||
@@ -232,21 +232,21 @@ __kernel void kernel_gemv_noshuffle(
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAT
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAT
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
|
||||
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAT
|
||||
#ifdef VECTOR_SUB_GROUP_BROADCAST
|
||||
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
|
||||
#else
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAT
|
||||
#endif // VECTOR_SUB_GROUP_BROADCAST
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
@@ -0,0 +1,30 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define QK4_0 32
|
||||
|
||||
kernel void kernel_moe_reorder_b(
|
||||
global float4 * src,
|
||||
global uint * router,
|
||||
global float4 * dst,
|
||||
global int * total_tiles,
|
||||
uint K,
|
||||
ushort map_ratio,
|
||||
uint tile_size
|
||||
) {
|
||||
uint k_4 = get_global_id(0);
|
||||
uint post_router_idx = get_global_id(1);
|
||||
|
||||
if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint router_idx = router[post_router_idx];
|
||||
|
||||
float4 out = (float4)(0);
|
||||
if (router_idx != 0xFFFFFFFF) {
|
||||
ushort activation_idx = router_idx / map_ratio;
|
||||
out = src[activation_idx * K / 4 + k_4];
|
||||
}
|
||||
|
||||
dst[post_router_idx * K / 4 + k_4] = out;
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
__kernel void kernel_moe_histogram(
|
||||
__global const int * input,
|
||||
__global int * hist,
|
||||
uint N,
|
||||
uint topK,
|
||||
uint n_experts
|
||||
) {
|
||||
uint n = get_global_id(0);
|
||||
uint k = get_global_id(1);
|
||||
|
||||
if (n >= N || k >= topK) {
|
||||
return;
|
||||
}
|
||||
|
||||
int expert_id = input[n * n_experts + k];
|
||||
atomic_inc(&hist[expert_id]);
|
||||
}
|
||||
|
||||
__kernel void kernel_moe_scan(
|
||||
__global int * hist,
|
||||
__global int * tile_offset,
|
||||
__global int * total_tiles,
|
||||
__global int * slot_counter,
|
||||
int tile_size,
|
||||
uint n_experts
|
||||
) {
|
||||
int offset = 0;
|
||||
for (int v = 0; v < n_experts; v++) {
|
||||
int count = hist[v];
|
||||
int tiles = (count + tile_size - 1) / tile_size;
|
||||
tile_offset[v] = offset;
|
||||
offset += tiles;
|
||||
hist[v] = 0;
|
||||
slot_counter[v] = 0;
|
||||
}
|
||||
|
||||
*total_tiles = offset;
|
||||
}
|
||||
|
||||
__kernel void kernel_moe_scatter(
|
||||
__global const int * input,
|
||||
__global int * post_router,
|
||||
__global ushort * emap,
|
||||
__global const int * tile_offset,
|
||||
__global int * slot_counter,
|
||||
int N,
|
||||
int topK,
|
||||
uint n_experts
|
||||
) {
|
||||
uint n = get_global_id(0);
|
||||
uint k = get_global_id(1);
|
||||
|
||||
if (n >= N || k >= topK) {
|
||||
return;
|
||||
}
|
||||
|
||||
int val = input[n * n_experts + k];
|
||||
|
||||
int local_slot = atomic_inc(&slot_counter[val]);
|
||||
|
||||
int tile_idx = tile_offset[val] + (local_slot / 32);
|
||||
int lane = local_slot % 32;
|
||||
int out_pos = tile_idx * 32 + lane;
|
||||
|
||||
post_router[out_pos] = n * topK + k;
|
||||
emap[tile_idx] = val;
|
||||
}
|
||||
|
||||
__kernel void kernel_moe_fill(
|
||||
__global int * post_router,
|
||||
__global int * total_tiles,
|
||||
int tile_size
|
||||
) {
|
||||
int tile_id = get_global_id(0);
|
||||
int vec_id_in_tile = get_global_id(1);
|
||||
|
||||
if (tile_id < total_tiles[0]) {
|
||||
post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF;
|
||||
}
|
||||
}
|
||||
@@ -207,35 +207,11 @@ struct ggml_backend_rpc_buffer_type_context {
|
||||
size_t max_size;
|
||||
};
|
||||
|
||||
struct graph_cache {
|
||||
|
||||
bool is_cached(const ggml_cgraph * cgraph) {
|
||||
if ((int)last_graph.size() != cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void add(const ggml_cgraph * cgraph) {
|
||||
last_graph.resize(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor> last_graph;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
graph_cache gc;
|
||||
uint64_t last_graph_uid;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_buffer_context {
|
||||
@@ -717,7 +693,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
||||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||
|
||||
GGML_ASSERT(cgraph->n_nodes > 0);
|
||||
bool reuse = rpc_ctx->gc.is_cached(cgraph);
|
||||
bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid;
|
||||
if (reuse) {
|
||||
rpc_msg_graph_recompute_req request;
|
||||
request.device = rpc_ctx->device;
|
||||
@@ -725,7 +701,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
} else {
|
||||
rpc_ctx->gc.add(cgraph);
|
||||
rpc_ctx->last_graph_uid = cgraph->uid;
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(rpc_ctx->device, cgraph, input);
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
@@ -791,10 +767,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
|
||||
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
|
||||
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
||||
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ device,
|
||||
/* .name = */ dev_name,
|
||||
/* .gc = */ {},
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ device,
|
||||
/* .name = */ dev_name,
|
||||
/* .last_graph_uid = */ 0,
|
||||
};
|
||||
auto reg = ggml_backend_rpc_add_server(endpoint);
|
||||
ggml_backend_t backend = new ggml_backend {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "virtgpu-shm.h"
|
||||
|
||||
#include "virtgpu.h"
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "virtgpu.h"
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
|
||||
@@ -18,8 +18,6 @@
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "ggml-remoting.h"
|
||||
|
||||
#define VIRGL_RENDERER_UNSTABLE_APIS 1
|
||||
#include "apir_hw.h"
|
||||
#include <drm/virtgpu_drm.h>
|
||||
|
||||
@@ -111,8 +111,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||
|
||||
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
|
||||
|
||||
#define GGML_VK_MAX_NODES 8192
|
||||
|
||||
#define VK_CHECK(err, msg) \
|
||||
do { \
|
||||
vk::Result err_ = (err); \
|
||||
@@ -440,10 +438,12 @@ struct vk_fa_pipeline_state {
|
||||
bool f32acc;
|
||||
uint32_t flags;
|
||||
uint32_t limit_occupancy_shmem;
|
||||
ggml_type k_type;
|
||||
ggml_type v_type;
|
||||
|
||||
bool operator<(const vk_fa_pipeline_state &b) const {
|
||||
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
|
||||
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
|
||||
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) <
|
||||
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -3041,7 +3041,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device
|
||||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
GGML_UNUSED(n_kv);
|
||||
GGML_UNUSED(f32acc);
|
||||
|
||||
@@ -3055,7 +3055,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
||||
if (small_rows) {
|
||||
result.block_rows = 32;
|
||||
result.block_cols = 32;
|
||||
} else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
|
||||
} else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) {
|
||||
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
|
||||
result.block_cols = 32;
|
||||
} else {
|
||||
@@ -3069,7 +3069,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
||||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
||||
// Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1.
|
||||
if (k_type != v_type) {
|
||||
GGML_ASSERT(device->coopmat2);
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
}
|
||||
|
||||
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
||||
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
@@ -3081,7 +3087,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
||||
if (path == FA_COOPMAT1) {
|
||||
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
||||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
||||
|
||||
if (!shape_ok || !shmem_ok) {
|
||||
@@ -3094,20 +3100,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it.
|
||||
if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) {
|
||||
path = FA_COOPMAT2;
|
||||
}
|
||||
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
case FA_COOPMAT1:
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc);
|
||||
case FA_COOPMAT2:
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
||||
default:
|
||||
throw std::runtime_error("unsupported FaCodePath");
|
||||
}
|
||||
}
|
||||
|
||||
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
|
||||
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
|
||||
bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) {
|
||||
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
|
||||
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
|
||||
|
||||
@@ -3118,12 +3129,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
|
||||
|
||||
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
|
||||
|
||||
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
|
||||
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type};
|
||||
}
|
||||
|
||||
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
|
||||
return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
|
||||
state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
|
||||
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
|
||||
// decodeBufF32 uses a block of vec4s for a better memory access pattern.
|
||||
return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t);
|
||||
};
|
||||
return {
|
||||
/* 0 WorkGroupSize */ state.workgroup_size,
|
||||
/* 1 Br */ state.Br,
|
||||
/* 2 Bc */ state.Bc,
|
||||
/* 3 HSK */ state.HSK,
|
||||
/* 4 HSV */ state.HSV,
|
||||
/* 5 Clamp */ static_cast<uint32_t>(!state.aligned),
|
||||
/* 6 D_split */ state.D_split,
|
||||
/* 7 row_split */ state.row_split,
|
||||
/* 8 SubGroupSize */ state.subgroup_size,
|
||||
/* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u,
|
||||
/*10 Flags */ state.flags,
|
||||
/*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem,
|
||||
/*12 FaTypeK */ static_cast<uint32_t>(state.k_type),
|
||||
/*13 FaTypeV */ static_cast<uint32_t>(state.v_type),
|
||||
/*14 FaBlockBytesK */ fa_block_bytes(state.k_type),
|
||||
/*15 FaBlockBytesV */ fa_block_bytes(state.v_type),
|
||||
};
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||
@@ -3578,16 +3609,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
#define CREATE_FA_CM2_MIXED() \
|
||||
for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \
|
||||
FaCodePath path = fa.first.path; \
|
||||
uint32_t Br = fa.first.Br; \
|
||||
uint32_t Bc = fa.first.Bc; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
if (path == FA_COOPMAT2) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA_CM2_MIXED();
|
||||
}
|
||||
#undef CREATE_FA_CM2_MIXED
|
||||
#endif
|
||||
#undef CREATE_FA
|
||||
|
||||
@@ -9042,8 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
assert(dst->type == GGML_TYPE_F32);
|
||||
assert(q->type == GGML_TYPE_F32);
|
||||
assert(k->type == v->type);
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
uint32_t workgroups_x = (uint32_t)neq1;
|
||||
@@ -9054,7 +9102,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
|
||||
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
|
||||
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
|
||||
|
||||
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
@@ -9067,7 +9115,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
workgroups_y /= gqa_ratio;
|
||||
}
|
||||
|
||||
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
|
||||
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
|
||||
|
||||
if (tuning_params.path != FA_COOPMAT2) {
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
@@ -9106,7 +9158,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
||||
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
|
||||
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
||||
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
||||
mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
@@ -15590,38 +15642,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
// It's straightforward to support different K/V dequant, but would
|
||||
// significantly increase the number of pipelines
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
// mismatching K/V type is currently supported for coopmat2 only.
|
||||
if (op->src[1]->type != op->src[2]->type && !coopmat2) {
|
||||
return false;
|
||||
}
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
// supported in scalar and coopmat2 paths
|
||||
break;
|
||||
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
||||
//case GGML_TYPE_Q2_K:
|
||||
//case GGML_TYPE_Q3_K:
|
||||
//case GGML_TYPE_Q4_K:
|
||||
//case GGML_TYPE_Q5_K:
|
||||
//case GGML_TYPE_Q6_K:
|
||||
//case GGML_TYPE_IQ1_S:
|
||||
//case GGML_TYPE_IQ1_M:
|
||||
//case GGML_TYPE_IQ2_XXS:
|
||||
//case GGML_TYPE_IQ2_XS:
|
||||
//case GGML_TYPE_IQ2_S:
|
||||
//case GGML_TYPE_IQ3_XXS:
|
||||
//case GGML_TYPE_IQ3_S:
|
||||
//case GGML_TYPE_IQ4_XS:
|
||||
|
||||
default:
|
||||
auto fa_kv_ok = [coopmat2](ggml_type t) {
|
||||
switch (t) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
case GGML_TYPE_Q1_0:
|
||||
return coopmat2;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
|
||||
return false;
|
||||
}
|
||||
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
|
||||
|
||||
@@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32;
|
||||
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
|
||||
layout (constant_id = 10) const uint32_t Flags = 0;
|
||||
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
|
||||
// ggml_type enumerant for K/V
|
||||
layout (constant_id = 12) const uint32_t FaTypeK = 0;
|
||||
layout (constant_id = 13) const uint32_t FaTypeV = 0;
|
||||
// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4).
|
||||
layout (constant_id = 14) const uint32_t FaBlockBytesK = 2;
|
||||
layout (constant_id = 15) const uint32_t FaBlockBytesV = 2;
|
||||
|
||||
const bool USE_MASK_OPT = (Flags & 1) != 0;
|
||||
const bool MASK_ENABLE = (Flags & 2) != 0;
|
||||
|
||||
@@ -17,8 +17,57 @@
|
||||
#extension GL_EXT_null_initializer : enable
|
||||
|
||||
#include "types.glsl"
|
||||
#include "dequant_funcs_cm2.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
#include "dequant_funcs_cm2.glsl"
|
||||
|
||||
// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V.
|
||||
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K {
|
||||
uint8_t raw[FaBlockBytesK];
|
||||
};
|
||||
layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V {
|
||||
uint8_t raw[FaBlockBytesV];
|
||||
};
|
||||
|
||||
uint fa_block_elems(uint ty) {
|
||||
switch (ty) {
|
||||
case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32)
|
||||
case 1u: return 1u; // GGML_TYPE_F16
|
||||
case 2u: return uint(QUANT_K_Q4_0);
|
||||
case 3u: return uint(QUANT_K_Q4_1);
|
||||
case 6u: return uint(QUANT_K_Q5_0);
|
||||
case 7u: return uint(QUANT_K_Q5_1);
|
||||
case 8u: return uint(QUANT_K_Q8_0);
|
||||
case 41u: return uint(QUANT_K_Q1_0);
|
||||
default:
|
||||
return 1u;
|
||||
}
|
||||
}
|
||||
|
||||
float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeK) {
|
||||
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
|
||||
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return float16_t(0);
|
||||
}
|
||||
}
|
||||
|
||||
float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeV) {
|
||||
case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock);
|
||||
case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return float16_t(0);
|
||||
}
|
||||
}
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
@@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
||||
return max(elem0, elem1);
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
#define DECODEFUNC , DEQUANTFUNC
|
||||
#else
|
||||
#define DECODEFUNC
|
||||
#endif
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
@@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c
|
||||
}
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
init_indices();
|
||||
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
@@ -107,10 +146,10 @@ void main() {
|
||||
|
||||
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
||||
|
||||
#if BLOCK_SIZE > 1
|
||||
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
|
||||
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
|
||||
#endif
|
||||
const uint bs_k = fa_block_elems(FaTypeK);
|
||||
const uint bs_v = fa_block_elems(FaTypeV);
|
||||
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k);
|
||||
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v);
|
||||
|
||||
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
|
||||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
|
||||
@@ -120,10 +159,12 @@ void main() {
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
q_stride &= ~7;
|
||||
#if BLOCK_SIZE == 1
|
||||
k_stride &= ~7;
|
||||
v_stride &= ~7;
|
||||
#endif
|
||||
if (bs_k == 1u) {
|
||||
k_stride &= ~7;
|
||||
}
|
||||
if (bs_v == 1u) {
|
||||
v_stride &= ~7;
|
||||
}
|
||||
m_stride &= ~7;
|
||||
}
|
||||
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
||||
@@ -230,7 +271,13 @@ void main() {
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
|
||||
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
|
||||
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
|
||||
const bool k_use_decode = (bs_k > 1u);
|
||||
if (k_use_decode) {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK);
|
||||
} else {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
|
||||
}
|
||||
S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (LOGIT_SOFTCAP) {
|
||||
@@ -291,7 +338,12 @@ void main() {
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
|
||||
const bool v_use_decode = (bs_v > 1u);
|
||||
if (v_use_decode) {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV);
|
||||
} else {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
|
||||
}
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
|
||||
@@ -641,20 +641,17 @@ void process_shaders() {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
|
||||
} else {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
|
||||
@@ -228,11 +228,13 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
|
||||
/** Row Norm **/
|
||||
|
||||
struct ggml_webgpu_row_norm_pipeline_key {
|
||||
ggml_op op;
|
||||
bool inplace;
|
||||
ggml_op op;
|
||||
ggml_type src_type;
|
||||
ggml_type dst_type;
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
|
||||
return op == other.op && inplace == other.inplace;
|
||||
return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -240,6 +242,8 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.src_type);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
@@ -1097,6 +1101,8 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_row_norm_pipeline_key key = {};
|
||||
key.op = context.dst->op;
|
||||
key.src_type = context.src0->type;
|
||||
key.dst_type = context.dst->type;
|
||||
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
||||
|
||||
auto it = row_norm_pipelines.find(key);
|
||||
@@ -1111,6 +1117,10 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back("RMS_NORM");
|
||||
variant = "rms_norm";
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
defines.push_back("NORM");
|
||||
variant = "norm";
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
defines.push_back("L2_NORM");
|
||||
variant = "l2_norm";
|
||||
@@ -1124,6 +1134,22 @@ class ggml_webgpu_shader_lib {
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
if (key.src_type == GGML_TYPE_F32) {
|
||||
defines.push_back("SRC_F32");
|
||||
variant += "_src_f32";
|
||||
} else if (key.src_type == GGML_TYPE_F16) {
|
||||
defines.push_back("SRC_F16");
|
||||
variant += "_src_f16";
|
||||
}
|
||||
|
||||
if (key.dst_type == GGML_TYPE_F32) {
|
||||
defines.push_back("DST_F32");
|
||||
variant += "_dst_f32";
|
||||
} else if (key.dst_type == GGML_TYPE_F16) {
|
||||
defines.push_back("DST_F16");
|
||||
variant += "_dst_f16";
|
||||
}
|
||||
|
||||
const uint32_t row_norm_wg_size = 128u;
|
||||
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
@@ -1779,12 +1805,12 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
||||
|
||||
auto it = mul_mat_fast_pipelines.find(key);
|
||||
@@ -2143,6 +2169,9 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
// variant suffix for src1 type
|
||||
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
|
||||
|
||||
|
||||
@@ -2927,6 +2927,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
|
||||
} else {
|
||||
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||
}
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||
case GGML_OP_ROPE:
|
||||
@@ -4071,6 +4072,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
#ifdef INPLACE
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
src[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
#if defined(SRC_F16) || defined(DST_F16)
|
||||
enable f16;
|
||||
#endif
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
#ifdef SRC_F16
|
||||
#define SRC_TYPE f16
|
||||
#else
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
dst[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
#define SRC_TYPE f32
|
||||
#endif
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
#ifdef DST_F16
|
||||
#define DST_TYPE f16
|
||||
#else
|
||||
#define DST_TYPE f32
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
@@ -40,9 +37,20 @@ struct Params {
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
var<storage, read_write> src: array<SRC_TYPE>;
|
||||
|
||||
var<workgroup> scratch: array<f32, WG_SIZE>;
|
||||
#ifdef INPLACE
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
#else
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<DST_TYPE>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
#endif
|
||||
|
||||
var<workgroup> scratch: array<f32, WG_SIZE * 2u>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||
@@ -65,34 +73,81 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
sum += pow(src[i_src_row + col], 2.0);
|
||||
let v = f32(src[i_src_row + col]);
|
||||
#ifdef NORM
|
||||
sum += v;
|
||||
#else
|
||||
sum += v * v;
|
||||
#endif
|
||||
col += WG_SIZE;
|
||||
}
|
||||
|
||||
scratch[lid.x] = sum;
|
||||
workgroupBarrier();
|
||||
var offset: u32 = WG_SIZE / 2;
|
||||
|
||||
var offset: u32 = WG_SIZE / 2u;
|
||||
while (offset > 0) {
|
||||
if (lid.x < offset) {
|
||||
scratch[lid.x] += scratch[lid.x + offset];
|
||||
}
|
||||
offset = offset / 2;
|
||||
offset /= 2u;
|
||||
workgroupBarrier();
|
||||
}
|
||||
sum = scratch[0];
|
||||
|
||||
#ifdef RMS_NORM
|
||||
#ifdef NORM
|
||||
let mean = sum / f32(params.ne0);
|
||||
var sq_sum = 0.0f;
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
let v = f32(src[i_src_row + col]);
|
||||
let d = v - mean;
|
||||
sq_sum += d * d;
|
||||
col += WG_SIZE;
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
scratch[lid.x] = sq_sum;
|
||||
workgroupBarrier();
|
||||
offset = WG_SIZE / 2u;
|
||||
while (offset > 0) {
|
||||
if (lid.x < offset) {
|
||||
scratch[lid.x] += scratch[lid.x + offset];
|
||||
}
|
||||
offset /= 2u;
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let variance = scratch[0] / f32(params.ne0);
|
||||
let scale = 1.0 / sqrt(variance + params.eps);
|
||||
#elif defined(RMS_NORM)
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
|
||||
#elif defined(L2_NORM)
|
||||
let scale = 1.0/max(sqrt(sum), params.eps);
|
||||
#endif
|
||||
|
||||
#ifdef NORM
|
||||
let mean_val = mean;
|
||||
#else
|
||||
let mean_val = 0.0f;
|
||||
#endif
|
||||
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
update(i_src_row + col, i_dst_row + col, scale);
|
||||
let i_src = i_src_row + col;
|
||||
let i_dst = i_dst_row + col;
|
||||
let v = src[i_src];
|
||||
#ifdef INPLACE
|
||||
src[i_dst] = scale * (v - mean_val);
|
||||
#else
|
||||
dst[i_dst] = scale * (v - mean_val);
|
||||
#endif
|
||||
col += WG_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,8 +55,13 @@
|
||||
|
||||
uint64_t ggml_graph_next_uid(void) {
|
||||
#ifdef _MSC_VER
|
||||
#if defined(_WIN32)
|
||||
static volatile LONG counter = 1;
|
||||
return (uint64_t) InterlockedIncrement(&counter) - 1;
|
||||
#else
|
||||
static volatile long long counter = 1;
|
||||
return (uint64_t) _InterlockedIncrement64(&counter) - 1;
|
||||
#endif
|
||||
#else
|
||||
static uint64_t counter = 1;
|
||||
return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED);
|
||||
@@ -3259,6 +3264,16 @@ void ggml_mul_mat_set_prec(
|
||||
ggml_set_op_params_i32(a, 0, prec_i32);
|
||||
}
|
||||
|
||||
void ggml_mul_mat_set_hint(
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op_hint hint) {
|
||||
GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
|
||||
|
||||
const int32_t hint_i32 = (int32_t) hint;
|
||||
|
||||
ggml_set_op_params_i32(a, 1, hint_i32);
|
||||
}
|
||||
|
||||
// ggml_mul_mat_id
|
||||
|
||||
/*
|
||||
|
||||
@@ -864,6 +864,9 @@ extern "C" {
|
||||
// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
|
||||
#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
|
||||
|
||||
// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load)
|
||||
#define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2
|
||||
|
||||
typedef uint32_t llama_state_seq_flags;
|
||||
|
||||
LLAMA_API size_t llama_state_seq_get_size_ext(
|
||||
|
||||
@@ -79,7 +79,7 @@ def print_info(msg):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def chat_completion(url, messages, tools=None, stream=False):
|
||||
def chat_completion(url, messages, tools=None, stream=False, force_tools=False):
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
@@ -87,7 +87,10 @@ def chat_completion(url, messages, tools=None, stream=False):
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
if force_tools:
|
||||
payload["tool_choice"] = "required"
|
||||
else:
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, stream=stream)
|
||||
@@ -160,7 +163,13 @@ def chat_completion(url, messages, tools=None, stream=False):
|
||||
return result
|
||||
|
||||
|
||||
def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6):
|
||||
def all_tools_called(tools, all_tool_calls):
|
||||
all_tool_names = set([tc["function"]["name"] for tc in tools])
|
||||
all_called_tool_names = set([tc["function"]["name"] for tc in all_tool_calls])
|
||||
return all_tool_names == all_called_tool_names
|
||||
|
||||
|
||||
def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6, force_tools=False):
|
||||
"""
|
||||
Drive the multi-turn tool-call loop:
|
||||
1. Send messages to model.
|
||||
@@ -172,8 +181,8 @@ def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turn
|
||||
msgs = list(messages)
|
||||
all_tool_calls: list[dict] = []
|
||||
|
||||
for _ in range(max_turns):
|
||||
result = chat_completion(url, msgs, tools=tools, stream=stream)
|
||||
for t in range(max_turns):
|
||||
result = chat_completion(url, msgs, tools=tools, stream=stream, force_tools=(force_tools and not all_tools_called(tools, all_tool_calls)))
|
||||
if result is None:
|
||||
return all_tool_calls, None
|
||||
|
||||
@@ -235,10 +244,10 @@ def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_test(url, test_case, stream):
|
||||
def run_test(url, test_case, stream, force_tools):
|
||||
name = test_case["name"]
|
||||
mode = f"{'stream' if stream else 'non-stream'}"
|
||||
print_header(f"{name} [{mode}]")
|
||||
print_header(f"{name} [{mode}, force_tools={force_tools}] ")
|
||||
|
||||
all_tool_calls, final_content = run_agentic_loop(
|
||||
url,
|
||||
@@ -246,6 +255,7 @@ def run_test(url, test_case, stream):
|
||||
tools=test_case["tools"],
|
||||
mock_tool_responses=test_case["mock_tool_responses"],
|
||||
stream=stream,
|
||||
force_tools=force_tools
|
||||
)
|
||||
|
||||
if final_content is None and not all_tool_calls:
|
||||
@@ -1093,6 +1103,9 @@ def main():
|
||||
parser.add_argument(
|
||||
"--stream-only", action="store_true", help="Only run streaming mode tests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-tools", action="store_true", help="Change tool mode to forced instead of auto"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
help="Run only the test whose name contains this substring (case-insensitive)",
|
||||
@@ -1103,10 +1116,13 @@ def main():
|
||||
print_info(f"Testing server at {url}")
|
||||
|
||||
modes = []
|
||||
force_tools = False
|
||||
if not args.stream_only:
|
||||
modes.append(False)
|
||||
if not args.no_stream:
|
||||
modes.append(True)
|
||||
if args.force_tools:
|
||||
force_tools = True
|
||||
|
||||
cases: list[dict] = ALL_TEST_CASES
|
||||
if args.test:
|
||||
@@ -1121,7 +1137,7 @@ def main():
|
||||
for stream in modes:
|
||||
for case in cases:
|
||||
total += 1
|
||||
if run_test(url, case, stream=stream):
|
||||
if run_test(url, case, stream=stream, force_tools=force_tools):
|
||||
passed += 1
|
||||
|
||||
color = GREEN if passed == total else RED
|
||||
|
||||
@@ -1 +1 @@
|
||||
387fa29fbbf3149f06a631c7850b6c35c24b0232
|
||||
ac6f7b44f60fde0091f0b3d99afde48f8c99b13a
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.43.2"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.43.3"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
+275
-27
@@ -2230,13 +2230,17 @@ llm_graph_cb llama_context::graph_get_cb() const {
|
||||
|
||||
class llama_io_write_dummy : public llama_io_write_i {
|
||||
public:
|
||||
llama_io_write_dummy() = default;
|
||||
llama_io_write_dummy(bool skip_tensors) : skip_tensors(skip_tensors) {}
|
||||
|
||||
void write(const void * /* src */, size_t size) override {
|
||||
size_written += size;
|
||||
}
|
||||
|
||||
void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
|
||||
void write_tensor(ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
|
||||
if (skip_tensors) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_written += size;
|
||||
}
|
||||
|
||||
@@ -2245,14 +2249,23 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
const bool skip_tensors;
|
||||
|
||||
size_t size_written = 0;
|
||||
};
|
||||
|
||||
class llama_io_write_buffer : public llama_io_write_i {
|
||||
class llama_io_write_host : public llama_io_write_i {
|
||||
public:
|
||||
llama_io_write_buffer(
|
||||
llama_io_write_host(
|
||||
uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||
|
||||
~llama_io_write_host() {
|
||||
// TODO: add backend support to batch tensor_get? or some other way to speed this up
|
||||
for (const auto & winfo : winfos) {
|
||||
ggml_backend_tensor_get(winfo.tensor, winfo.ptr, winfo.offset, winfo.size);
|
||||
}
|
||||
}
|
||||
|
||||
void write(const void * src, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
@@ -2263,11 +2276,14 @@ public:
|
||||
buf_size -= size;
|
||||
}
|
||||
|
||||
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
ggml_backend_tensor_get(tensor, ptr, offset, size);
|
||||
|
||||
// save the write for later during destruction
|
||||
winfos.push_back({tensor, ptr, size, offset});
|
||||
|
||||
ptr += size;
|
||||
size_written += size;
|
||||
buf_size -= size;
|
||||
@@ -2281,25 +2297,48 @@ private:
|
||||
uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_written = 0;
|
||||
|
||||
struct write_info {
|
||||
ggml_tensor * tensor;
|
||||
uint8_t * ptr;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
};
|
||||
std::vector<write_info> winfos;
|
||||
};
|
||||
|
||||
class llama_io_read_buffer : public llama_io_read_i {
|
||||
class llama_io_read_host : public llama_io_read_i {
|
||||
public:
|
||||
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||
llama_io_read_host(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
|
||||
|
||||
const uint8_t * read(size_t size) override {
|
||||
const uint8_t * base_ptr = ptr;
|
||||
~llama_io_read_host() {
|
||||
// flush the reads
|
||||
for (const auto & rinfo : rinfos) {
|
||||
ggml_backend_tensor_set(rinfo.tensor, rinfo.ptr, rinfo.offset, rinfo.size);
|
||||
}
|
||||
}
|
||||
|
||||
void read(void * dst, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
memcpy(dst, ptr, size);
|
||||
ptr += size;
|
||||
size_read += size;
|
||||
buf_size -= size;
|
||||
return base_ptr;
|
||||
}
|
||||
|
||||
void read_to(void * dst, size_t size) override {
|
||||
memcpy(dst, read(size), size);
|
||||
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
|
||||
// save for later during destruction
|
||||
rinfos.push_back({tensor, ptr, size, offset});
|
||||
|
||||
ptr += size;
|
||||
size_read += size;
|
||||
buf_size -= size;
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
@@ -2310,6 +2349,14 @@ private:
|
||||
const uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_read = 0;
|
||||
|
||||
struct read_info {
|
||||
ggml_tensor * tensor;
|
||||
const uint8_t * ptr;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
};
|
||||
std::vector<read_info> rinfos;
|
||||
};
|
||||
|
||||
class llama_io_write_file : public llama_io_write_i {
|
||||
@@ -2321,7 +2368,7 @@ public:
|
||||
size_written += size;
|
||||
}
|
||||
|
||||
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
temp_buffer.resize(size);
|
||||
ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
|
||||
write(temp_buffer.data(), temp_buffer.size());
|
||||
@@ -2341,15 +2388,15 @@ class llama_io_read_file : public llama_io_read_i {
|
||||
public:
|
||||
llama_io_read_file(llama_file * f) : file(f) {}
|
||||
|
||||
void read_to(void * dst, size_t size) override {
|
||||
void read(void * dst, size_t size) override {
|
||||
file->read_raw(dst, size);
|
||||
size_read += size;
|
||||
}
|
||||
|
||||
const uint8_t * read(size_t size) override {
|
||||
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
temp_buffer.resize(size);
|
||||
read_to(temp_buffer.data(), size);
|
||||
return temp_buffer.data();
|
||||
read(temp_buffer.data(), size);
|
||||
ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size);
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
@@ -2362,8 +2409,162 @@ private:
|
||||
std::vector<uint8_t> temp_buffer;
|
||||
};
|
||||
|
||||
class llama_io_write_device : public llama_io_write_i {
|
||||
public:
|
||||
llama_io_write_device(uint8_t * p, size_t len, llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) {
|
||||
}
|
||||
|
||||
~llama_io_write_device() {
|
||||
llama_memory_buffers mbufs_new;
|
||||
|
||||
for (const auto & winfo : winfos) {
|
||||
auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer);
|
||||
|
||||
mbufs_new[buft].n_tensors++;
|
||||
mbufs_new[buft].total_size += winfo.size;
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ 2*mbuf.n_tensors*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
mbuf.ctx.reset(ggml_init(params));
|
||||
|
||||
mbuf.org.reserve(mbuf.n_tensors);
|
||||
mbuf.cpy.reserve(mbuf.n_tensors);
|
||||
}
|
||||
|
||||
for (const auto & winfo : winfos) {
|
||||
auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer);
|
||||
|
||||
const int64_t n = winfo.size/ggml_element_size(winfo.tensor);
|
||||
|
||||
auto & mbuf = mbufs_new[buft];
|
||||
|
||||
mbuf.org.push_back(ggml_view_1d (mbuf.ctx.get(), winfo.tensor, n, winfo.offset));
|
||||
mbuf.cpy.push_back(ggml_new_tensor_1d(mbuf.ctx.get(), winfo.tensor->type, n));
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
auto & mbuf_cur = mbufs[buft];
|
||||
|
||||
if (!mbuf_cur.buf || mbuf_cur.org.size() != mbuf.org.size() || mbuf_cur.total_size != mbuf.total_size) {
|
||||
mbuf_cur = std::move(mbuf);
|
||||
|
||||
mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft));
|
||||
|
||||
LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
ggml_backend_tensor_copy(mbuf_cur.org[i], mbuf_cur.cpy[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void write(const void * src, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
memcpy(ptr, src, size);
|
||||
ptr += size;
|
||||
size_written += size;
|
||||
buf_size -= size;
|
||||
}
|
||||
|
||||
void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
// save the write for later during destruction
|
||||
winfos.push_back({tensor, ptr, size, offset});
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
return size_written;
|
||||
}
|
||||
|
||||
private:
|
||||
uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_written = 0;
|
||||
|
||||
struct write_info {
|
||||
ggml_tensor * tensor;
|
||||
uint8_t * ptr;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
};
|
||||
std::vector<write_info> winfos;
|
||||
|
||||
llama_memory_buffers & mbufs;
|
||||
};
|
||||
|
||||
class llama_io_read_device : public llama_io_read_i {
|
||||
public:
|
||||
llama_io_read_device(const uint8_t * p, size_t len, const llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) {
|
||||
}
|
||||
|
||||
~llama_io_read_device() {
|
||||
llama_memory_buffers mbufs_new;
|
||||
|
||||
for (const auto & rinfo : rinfos) {
|
||||
auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer);
|
||||
|
||||
mbufs_new[buft].n_tensors++;
|
||||
mbufs_new[buft].total_size += rinfo.size;
|
||||
}
|
||||
|
||||
for (auto & [buft, mbuf] : mbufs_new) {
|
||||
const auto & mbuf_cur = mbufs.at(buft);
|
||||
|
||||
if (!mbuf_cur.buf || mbuf_cur.n_tensors != mbuf.n_tensors || mbuf_cur.total_size != mbuf.total_size) {
|
||||
GGML_ABORT("%s: memory buffer mismatch\n", __func__);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < mbuf_cur.org.size(); ++i) {
|
||||
ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf_cur.org[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void read(void * dst, size_t size) override {
|
||||
if (size > buf_size) {
|
||||
throw std::runtime_error("unexpectedly reached end of buffer");
|
||||
}
|
||||
memcpy(dst, ptr, size);
|
||||
ptr += size;
|
||||
size_read += size;
|
||||
buf_size -= size;
|
||||
}
|
||||
|
||||
void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override {
|
||||
// save for later during destruction
|
||||
rinfos.push_back({tensor, ptr, size, offset});
|
||||
}
|
||||
|
||||
size_t n_bytes() override {
|
||||
return size_read;
|
||||
}
|
||||
|
||||
private:
|
||||
const uint8_t * ptr;
|
||||
size_t buf_size = 0;
|
||||
size_t size_read = 0;
|
||||
|
||||
struct read_info {
|
||||
ggml_tensor * tensor;
|
||||
const uint8_t * ptr;
|
||||
size_t size;
|
||||
size_t offset;
|
||||
};
|
||||
std::vector<read_info> rinfos;
|
||||
|
||||
const llama_memory_buffers & mbufs;
|
||||
};
|
||||
|
||||
size_t llama_context::state_get_size() {
|
||||
llama_io_write_dummy io;
|
||||
llama_io_write_dummy io(false);
|
||||
try {
|
||||
return state_write_data(io);
|
||||
} catch (const std::exception & err) {
|
||||
@@ -2373,7 +2574,7 @@ size_t llama_context::state_get_size() {
|
||||
}
|
||||
|
||||
size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
|
||||
llama_io_write_buffer io(dst, size);
|
||||
llama_io_write_host io(dst, size);
|
||||
try {
|
||||
return state_write_data(io);
|
||||
} catch (const std::exception & err) {
|
||||
@@ -2383,7 +2584,7 @@ size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
|
||||
}
|
||||
|
||||
size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
|
||||
llama_io_read_buffer io(src, size);
|
||||
llama_io_read_host io(src, size);
|
||||
try {
|
||||
return state_read_data(io);
|
||||
} catch (const std::exception & err) {
|
||||
@@ -2392,9 +2593,14 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr uint32_t io_magic = 0xaf143cd8;
|
||||
|
||||
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_io_write_dummy io;
|
||||
llama_io_write_dummy io(flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
try {
|
||||
io.write(&io_magic, sizeof(io_magic));
|
||||
io.write(&seq_id, sizeof(seq_id));
|
||||
|
||||
return state_seq_write_data(io, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
||||
@@ -2403,9 +2609,18 @@ size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_fl
|
||||
}
|
||||
|
||||
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
|
||||
llama_io_write_buffer io(dst, size);
|
||||
std::unique_ptr<llama_io_write_i> io;
|
||||
if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) {
|
||||
io = std::make_unique<llama_io_write_device>(dst, size, mem_storage[seq_id]);
|
||||
} else {
|
||||
io = std::make_unique<llama_io_write_host>(dst, size);
|
||||
}
|
||||
|
||||
try {
|
||||
return state_seq_write_data(io, seq_id, flags);
|
||||
io->write(&io_magic, sizeof(io_magic));
|
||||
io->write(&seq_id, sizeof(seq_id));
|
||||
|
||||
return state_seq_write_data(*io, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
@@ -2413,9 +2628,43 @@ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, siz
|
||||
}
|
||||
|
||||
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
|
||||
llama_io_read_buffer io(src, size);
|
||||
std::unique_ptr<llama_io_read_i> io;
|
||||
if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) {
|
||||
// create a temporary io to read the magic and the src seq_id
|
||||
io = std::make_unique<llama_io_read_host>(src, size);
|
||||
|
||||
uint32_t magic_read;
|
||||
io->read(&magic_read, sizeof(magic_read));
|
||||
if (io_magic != magic_read) {
|
||||
throw std::runtime_error("wrong sequence state magic");
|
||||
}
|
||||
|
||||
llama_seq_id seq_id_read;
|
||||
io->read(&seq_id_read, sizeof(seq_id_read));
|
||||
|
||||
GGML_ASSERT(mem_storage.find(seq_id_read) != mem_storage.end());
|
||||
|
||||
io = std::make_unique<llama_io_read_device>(src, size, mem_storage[seq_id_read]);
|
||||
} else {
|
||||
io = std::make_unique<llama_io_read_host>(src, size);
|
||||
}
|
||||
|
||||
try {
|
||||
return state_seq_read_data(io, seq_id, flags);
|
||||
uint32_t magic_read;
|
||||
io->read(&magic_read, sizeof(magic_read));
|
||||
if (io_magic != magic_read) {
|
||||
throw std::runtime_error("wrong sequence state magic");
|
||||
}
|
||||
|
||||
const bool need_seq_match = (flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
llama_seq_id seq_id_read;
|
||||
io->read(&seq_id_read, sizeof(seq_id_read));
|
||||
if (need_seq_match && seq_id != seq_id_read) {
|
||||
throw std::runtime_error("wrong sequence id");
|
||||
}
|
||||
|
||||
return state_seq_read_data(*io, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
@@ -3406,7 +3655,6 @@ size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t s
|
||||
|
||||
return ctx->state_seq_get_data(seq_id, dst, size, flags);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
ctx->synchronize();
|
||||
|
||||
|
||||
@@ -23,6 +23,21 @@ class llama_io_write_i;
|
||||
struct llama_memory_i;
|
||||
struct llama_memory_context_i;
|
||||
|
||||
// stores copy of the memory in device buffer. used for fast state save/load
|
||||
struct llama_memory_buffer {
|
||||
int n_tensors = 0;
|
||||
size_t total_size = 0;
|
||||
|
||||
ggml_backend_buffer_ptr buf;
|
||||
|
||||
ggml_context_ptr ctx;
|
||||
|
||||
std::vector<ggml_tensor *> org;
|
||||
std::vector<ggml_tensor *> cpy;
|
||||
};
|
||||
|
||||
using llama_memory_buffers = std::map<ggml_backend_buffer_type_t, llama_memory_buffer>;
|
||||
|
||||
struct llama_context {
|
||||
// init scheduler and compute buffers, reserve worst-case graphs
|
||||
llama_context(
|
||||
@@ -128,6 +143,7 @@ struct llama_context {
|
||||
size_t state_set_data(const uint8_t * src, size_t size);
|
||||
|
||||
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
|
||||
|
||||
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
|
||||
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
|
||||
|
||||
@@ -328,6 +344,9 @@ private:
|
||||
// host buffer for the model output (logits and embeddings)
|
||||
ggml_backend_buffer_ptr buf_output;
|
||||
|
||||
// keep copies of the per-sequence memory on the device
|
||||
std::map<llama_seq_id, llama_memory_buffers> mem_storage;
|
||||
|
||||
bool has_evaluated_once = false;
|
||||
|
||||
// env: LLAMA_GRAPH_REUSE_DISABLE
|
||||
|
||||
+6
-1
@@ -65,8 +65,13 @@ static ggml_tensor * ggml_mul_mat_aux(
|
||||
|
||||
ggml_tensor * res;
|
||||
|
||||
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
|
||||
if (!ggml_is_contiguous(cur)) {
|
||||
res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n);
|
||||
} else {
|
||||
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
|
||||
}
|
||||
res = ggml_mul_mat (ctx, rot, res);
|
||||
ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD);
|
||||
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
|
||||
|
||||
return res;
|
||||
|
||||
+7
-2
@@ -1,5 +1,7 @@
|
||||
#include "llama-io.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
void llama_io_write_i::write_string(const std::string & str) {
|
||||
uint32_t str_size = str.size();
|
||||
|
||||
@@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) {
|
||||
|
||||
void llama_io_read_i::read_string(std::string & str) {
|
||||
uint32_t str_size;
|
||||
read_to(&str_size, sizeof(str_size));
|
||||
read(&str_size, sizeof(str_size));
|
||||
|
||||
str.assign((const char *) read(str_size), str_size);
|
||||
std::vector<char> buf(str_size);
|
||||
read(buf.data(), str_size);
|
||||
|
||||
str.assign(buf.data(), str_size);
|
||||
}
|
||||
|
||||
+3
-3
@@ -12,7 +12,7 @@ public:
|
||||
virtual ~llama_io_write_i() = default;
|
||||
|
||||
virtual void write(const void * src, size_t size) = 0;
|
||||
virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
|
||||
virtual void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0;
|
||||
|
||||
// bytes written so far
|
||||
virtual size_t n_bytes() = 0;
|
||||
@@ -25,8 +25,8 @@ public:
|
||||
llama_io_read_i() = default;
|
||||
virtual ~llama_io_read_i() = default;
|
||||
|
||||
virtual const uint8_t * read(size_t size) = 0;
|
||||
virtual void read_to(void * dst, size_t size) = 0;
|
||||
virtual void read(void * dst, size_t size) = 0;
|
||||
virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0;
|
||||
|
||||
// bytes read so far
|
||||
virtual size_t n_bytes() = 0;
|
||||
|
||||
+26
-28
@@ -67,6 +67,7 @@ static ggml_tensor * ggml_mul_mat_aux(
|
||||
|
||||
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
|
||||
res = ggml_mul_mat (ctx, rot, res);
|
||||
ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD);
|
||||
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
|
||||
|
||||
return res;
|
||||
@@ -1900,14 +1901,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
|
||||
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
|
||||
|
||||
uint32_t n_stream_cur;
|
||||
io.read_to(&n_stream_cur, sizeof(n_stream_cur));
|
||||
io.read(&n_stream_cur, sizeof(n_stream_cur));
|
||||
if (n_stream_cur != n_stream) {
|
||||
throw std::runtime_error("n_stream mismatch");
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
io.read(&cell_count, sizeof(cell_count));
|
||||
|
||||
if (cell_count == 0) {
|
||||
continue;
|
||||
@@ -2082,8 +2083,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 1) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
@@ -2092,7 +2093,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
llama_kv_cell_ext ext;
|
||||
io.read_to(&ext, sizeof(ext));
|
||||
io.read(&ext, sizeof(ext));
|
||||
|
||||
ubatch.pos[i + ubatch.n_tokens] = ext.y;
|
||||
ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
|
||||
@@ -2101,7 +2102,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
||||
{
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
|
||||
ubatch.pos[i] = pos;
|
||||
@@ -2143,20 +2144,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cells.pos_set(i, pos);
|
||||
|
||||
if (hparams.n_pos_per_embd() > 1) {
|
||||
llama_kv_cell_ext ext;
|
||||
io.read_to(&ext, sizeof(ext));
|
||||
io.read(&ext, sizeof(ext));
|
||||
cells.ext_set(i, ext);
|
||||
}
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
|
||||
@@ -2189,8 +2190,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
uint32_t v_trans;
|
||||
uint32_t n_layer;
|
||||
|
||||
io.read_to(&v_trans, sizeof(v_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
io.read(&v_trans, sizeof(v_trans));
|
||||
io.read(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != layers.size()) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
||||
@@ -2217,7 +2218,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
io.read(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
const int32_t k_type_i = (int32_t) k->type;
|
||||
if (k_type_i != k_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
||||
@@ -2226,7 +2227,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read row size of key
|
||||
uint64_t k_size_row_ref;
|
||||
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
io.read(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
|
||||
if (k_size_row != k_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
||||
@@ -2236,13 +2237,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
if (cell_count) {
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * k_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
|
||||
ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
|
||||
io.read_tensor(k, dst_offset, k_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2261,7 +2261,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
@@ -2270,7 +2270,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read row size of value
|
||||
uint64_t v_size_row_ref;
|
||||
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
io.read(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
|
||||
if (v_size_row != v_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
||||
@@ -2280,13 +2280,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
if (cell_count) {
|
||||
if (sinfo.is_contiguous()) {
|
||||
// Fast path: contiguous cells, single memcpy
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
const void * src = io.read(cell_count * v_size_row);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
|
||||
io.read_tensor(v, dst_offset, v_size_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2305,7 +2304,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t) v->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
@@ -2314,7 +2313,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read element size of value
|
||||
uint32_t v_size_el_ref;
|
||||
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
io.read(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
const size_t v_size_el = ggml_type_size(v->type);
|
||||
if (v_size_el != v_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
||||
@@ -2323,7 +2322,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
|
||||
// Read GQA embedding size
|
||||
uint32_t n_embd_v_gqa_ref;
|
||||
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
||||
return false;
|
||||
@@ -2335,15 +2334,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
|
||||
const uint32_t h = sinfo.head();
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (h + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
io.read_tensor(v, dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
} else {
|
||||
// Slow path: scatter to non-contiguous positions
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const void * src = io.read(cell_count * v_size_el);
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
|
||||
io.read_tensor(v, dst_offset, v_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -743,7 +743,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
io.read(&cell_count, sizeof(cell_count));
|
||||
|
||||
bool res = true;
|
||||
|
||||
@@ -784,7 +784,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
io.write(&s_trans, sizeof(s_trans));
|
||||
io.write(&n_layer, sizeof(n_layer));
|
||||
io.write(&n_layer, sizeof(n_layer));
|
||||
|
||||
// Iterate and write all the R tensors first, each row is a cell
|
||||
// Get whole range at a time
|
||||
@@ -879,8 +879,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
@@ -920,14 +920,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
io.read(&pos, sizeof(pos));
|
||||
io.read(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cell.pos = pos;
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
io.read(&seq_id, sizeof(seq_id));
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max);
|
||||
@@ -961,8 +961,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t s_trans;
|
||||
uint32_t n_layer;
|
||||
io.read_to(&s_trans, sizeof(s_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
io.read(&s_trans, sizeof(s_trans));
|
||||
io.read(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != hparams.n_layer) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
||||
@@ -984,7 +984,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read type of key
|
||||
int32_t r_type_i_ref;
|
||||
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
io.read(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
||||
if (r_type_i != r_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
||||
@@ -993,7 +993,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read row size of key
|
||||
uint64_t r_size_row_ref;
|
||||
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
io.read(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
if (r_size_row != r_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||
@@ -1002,7 +1002,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
||||
io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1013,7 +1013,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read type of value
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
@@ -1023,7 +1023,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read row size of value
|
||||
uint64_t s_size_row_ref;
|
||||
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
io.read(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
if (s_size_row != s_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
@@ -1032,7 +1032,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
||||
io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1045,7 +1045,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read type of value
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
@@ -1054,7 +1054,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read element size of value
|
||||
uint32_t s_size_el_ref;
|
||||
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
io.read(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||
if (s_size_el != s_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
||||
@@ -1063,7 +1063,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// Read state embedding size
|
||||
uint32_t n_embd_s_ref;
|
||||
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
io.read(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
if (n_embd_s != n_embd_s_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
||||
return false;
|
||||
@@ -1073,7 +1073,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * s_size_el;
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
||||
io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+345
-7317
File diff suppressed because it is too large
Load Diff
+78
-11
@@ -577,14 +577,8 @@ struct llama_model {
|
||||
int64_t t_load_us = 0;
|
||||
int64_t t_start_us = 0;
|
||||
|
||||
explicit llama_model(const struct llama_model_params & params);
|
||||
~llama_model();
|
||||
|
||||
void load_stats (llama_model_loader & ml);
|
||||
void load_arch (llama_model_loader & ml);
|
||||
void load_hparams(llama_model_loader & ml);
|
||||
void load_vocab (llama_model_loader & ml);
|
||||
bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
|
||||
explicit llama_model(const llama_model_params & params);
|
||||
virtual ~llama_model();
|
||||
|
||||
std::string arch_name() const;
|
||||
std::string type_name() const;
|
||||
@@ -620,21 +614,94 @@ struct llama_model {
|
||||
|
||||
ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
|
||||
|
||||
// TODO: move this to new llm_arch_model_i interface
|
||||
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;
|
||||
|
||||
// TODO: move this to new llm_arch_model_i interface
|
||||
ggml_cgraph * build_graph(const llm_graph_params & params) const;
|
||||
|
||||
private:
|
||||
virtual void load_stats (llama_model_loader & ml) = 0;
|
||||
virtual void load_hparams(llama_model_loader & ml) = 0;
|
||||
virtual void load_vocab (llama_model_loader & ml) = 0;
|
||||
virtual bool load_tensors(llama_model_loader & ml) = 0; // returns false if cancelled by progress_callback
|
||||
|
||||
// model must define these
|
||||
virtual void load_arch_hparams(llama_model_loader & ml) = 0;
|
||||
virtual void load_arch_tensors(llama_model_loader & ml) = 0;
|
||||
virtual std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const = 0;
|
||||
|
||||
protected:
|
||||
llama_model_params params;
|
||||
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
llama_model * llama_model_create(llm_arch arch, const llama_model_params & params);
|
||||
llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params);
|
||||
|
||||
// model must inherit from this
|
||||
struct llama_model_base : public llama_model {
|
||||
friend struct llama_model;
|
||||
|
||||
llama_model * model;
|
||||
llama_model_loader * ml = nullptr;
|
||||
const LLM_TN tn;
|
||||
|
||||
// llama_model_loader is not yet defined at this point, so we will set it after construction
|
||||
const int TENSOR_DUPLICATED;
|
||||
const int TENSOR_NOT_REQUIRED;
|
||||
const int TENSOR_SKIP;
|
||||
const int TENSOR_SKIP_IF_VIRTUAL;
|
||||
|
||||
explicit llama_model_base(const llama_model_params & params);
|
||||
virtual ~llama_model_base() = default;
|
||||
|
||||
ggml_tensor * create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags);
|
||||
|
||||
// convenience overload of create_tensor that doesn't require llama_model_loader
|
||||
ggml_tensor * create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags);
|
||||
|
||||
// helper: try merged gate_up_exps first, fall back to separate gate and up
|
||||
void create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_,
|
||||
int64_t n_ff_, int64_t n_expert_, int flags);
|
||||
|
||||
// helper: try to load merged qkv first, fall back to separate q, k, v
|
||||
void create_tensor_qkv(llama_layer & layer, int bid,
|
||||
int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_,
|
||||
int flags);
|
||||
|
||||
void load_stats (llama_model_loader & ml) override;
|
||||
void load_hparams(llama_model_loader & ml) override;
|
||||
void load_vocab (llama_model_loader & ml) override;
|
||||
bool load_tensors(llama_model_loader & ml) override;
|
||||
|
||||
// model must define these
|
||||
void load_arch_hparams(llama_model_loader & ml) override = 0;
|
||||
void load_arch_tensors(llama_model_loader & ml) override = 0;
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override = 0;
|
||||
};
|
||||
|
||||
const char * llm_type_name(llm_type type);
|
||||
|
||||
// convenience macro for loading local variables for load_tensors() in llama_model_base
|
||||
// note: cast to int64_t since we will use these for the tensor dimensions
|
||||
#define LLAMA_LOAD_LOCALS \
|
||||
const int n_layer = hparams.n_layer; GGML_UNUSED(n_layer); \
|
||||
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
|
||||
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
|
||||
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); GGML_UNUSED(n_embd_k_gqa); \
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_UNUSED(n_embd_v_gqa); \
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k(); GGML_UNUSED(n_embd_head_k); \
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v(); GGML_UNUSED(n_embd_head_v); \
|
||||
const int64_t n_ff = hparams.n_ff(); GGML_UNUSED(n_ff); \
|
||||
const int64_t n_embd_gqa = n_embd_v_gqa; GGML_UNUSED(n_embd_gqa); \
|
||||
const int64_t n_vocab = vocab.n_tokens(); GGML_UNUSED(n_vocab); \
|
||||
const int64_t n_token_types = vocab.n_token_types(); GGML_UNUSED(n_token_types); \
|
||||
const int64_t n_rot = hparams.n_rot(); GGML_UNUSED(n_rot); \
|
||||
const int64_t n_expert = hparams.n_expert; GGML_UNUSED(n_expert); \
|
||||
const int64_t n_expert_used = hparams.n_expert_used; GGML_UNUSED(n_expert_used); \
|
||||
const int64_t n_ctx_train = hparams.n_ctx_train; GGML_UNUSED(n_ctx_train);
|
||||
|
||||
// For internal test use
|
||||
// TODO: remove
|
||||
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);
|
||||
|
||||
+16
-11
@@ -683,9 +683,9 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_mod
|
||||
LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n",
|
||||
__func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype));
|
||||
new_type = qtype;
|
||||
manual = true;
|
||||
break;
|
||||
}
|
||||
manual = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -882,13 +882,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
|
||||
ml.init_mappings(false); // no prefetching
|
||||
|
||||
llama_model model(llama_model_default_params());
|
||||
auto mparams = llama_model_default_params();
|
||||
std::unique_ptr<llama_model> model_ptr(llama_model_create(ml, mparams));
|
||||
|
||||
model.load_arch (ml);
|
||||
model.load_hparams(ml);
|
||||
model.load_stats (ml);
|
||||
auto * model = dynamic_cast<llama_model_base *>(model_ptr.get());
|
||||
if (model == nullptr) {
|
||||
GGML_ABORT("fatal error: model does not implement llama_model_base");
|
||||
}
|
||||
|
||||
quantize_state_impl qs(model, params);
|
||||
model->load_hparams(ml);
|
||||
model->load_stats (ml);
|
||||
|
||||
quantize_state_impl qs(*model, params);
|
||||
|
||||
if (params->only_copy) {
|
||||
ftype = ml.ftype;
|
||||
@@ -1023,7 +1028,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
}
|
||||
gguf_add_tensor(ctx_outs[i_split].get(), tensor);
|
||||
|
||||
metadata[i].allows_quantization = tensor_allows_quantization(params, model.arch, tensor);
|
||||
metadata[i].allows_quantization = tensor_allows_quantization(params, model->arch, tensor);
|
||||
|
||||
if (metadata[i].allows_quantization) {
|
||||
metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]);
|
||||
@@ -1331,9 +1336,9 @@ void llama_quant_free(quantize_state_impl * qs) {
|
||||
|
||||
llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) {
|
||||
struct llama_model_params mparams = llama_model_default_params();
|
||||
auto * model = new llama_model(mparams);
|
||||
|
||||
model->arch = llm_arch_from_string(desc->architecture);
|
||||
auto arch = llm_arch_from_string(desc->architecture);
|
||||
auto * model = llama_model_create(arch, mparams);
|
||||
model->arch = arch;
|
||||
|
||||
// infer llm_type: only LLM_TYPE_70B matters for quantization logic
|
||||
if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) {
|
||||
|
||||
+127
-113
@@ -89,6 +89,10 @@ void llama_backend_init(void) {
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
if (!ggml_backend_reg_count()) {
|
||||
ggml_backend_load_all();
|
||||
}
|
||||
}
|
||||
|
||||
void llama_numa_init(enum ggml_numa_strategy numa) {
|
||||
@@ -111,113 +115,8 @@ int64_t llama_time_us(void) {
|
||||
return ggml_time_us();
|
||||
}
|
||||
|
||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||
static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud,
|
||||
const std::string & fname, std::vector<std::string> & splits, FILE * file, llama_model & model, llama_model_params & params) {
|
||||
// loading time will be recalculated after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
model.t_load_us = 0;
|
||||
time_meas tm(model.t_load_us);
|
||||
|
||||
model.t_start_us = tm.t_start_us;
|
||||
|
||||
try {
|
||||
llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io,
|
||||
params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
|
||||
|
||||
ml.print_info();
|
||||
|
||||
model.hparams.vocab_only = params.vocab_only;
|
||||
model.hparams.no_alloc = params.no_alloc;
|
||||
|
||||
try {
|
||||
model.load_arch(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
|
||||
}
|
||||
try {
|
||||
model.load_hparams(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
||||
}
|
||||
if (model.arch == LLM_ARCH_CLIP) {
|
||||
throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
|
||||
}
|
||||
try {
|
||||
model.load_vocab(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
model.load_stats(ml);
|
||||
model.print_info();
|
||||
|
||||
if (params.vocab_only) {
|
||||
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!model.load_tensors(ml)) {
|
||||
return -2;
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static struct llama_model * llama_model_load_from_file_impl(
|
||||
struct gguf_context * metadata,
|
||||
llama_model_set_tensor_data_t set_tensor_data,
|
||||
void * set_tensor_data_ud,
|
||||
const std::string & path_model,
|
||||
std::vector<std::string> & splits,
|
||||
FILE * file,
|
||||
struct llama_model_params params) {
|
||||
{
|
||||
int n_sources_defined = 0;
|
||||
if (metadata != nullptr) {
|
||||
n_sources_defined++;
|
||||
}
|
||||
if (!path_model.empty()) {
|
||||
n_sources_defined++;
|
||||
}
|
||||
if (file != nullptr) {
|
||||
n_sources_defined++;
|
||||
}
|
||||
if (n_sources_defined != 1) {
|
||||
LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
ggml_time_init();
|
||||
|
||||
if (!params.vocab_only && ggml_backend_reg_count() == 0) {
|
||||
LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned cur_percentage = 0;
|
||||
if (params.progress_callback == NULL) {
|
||||
params.progress_callback_user_data = &cur_percentage;
|
||||
params.progress_callback = [](float progress, void * ctx) {
|
||||
unsigned * cur_percentage_p = (unsigned *) ctx;
|
||||
unsigned percentage = (unsigned) (100 * progress);
|
||||
while (percentage > *cur_percentage_p) {
|
||||
*cur_percentage_p = percentage;
|
||||
LLAMA_LOG_CONT(".");
|
||||
if (percentage >= 100) {
|
||||
LLAMA_LOG_CONT("\n");
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
}
|
||||
|
||||
llama_model * model = new llama_model(params);
|
||||
|
||||
// returns true on success
|
||||
static bool llama_prepare_model_devices(const llama_model_params & params, llama_model * model) {
|
||||
// create list of devices to use with this model
|
||||
if (params.devices) {
|
||||
if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) {
|
||||
@@ -227,7 +126,7 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
}
|
||||
if (n_devs == 0) {
|
||||
LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__);
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: creating a Meta device with %zu devices\n", __func__, n_devs);
|
||||
for (size_t i = 0; i < n_devs; ++i) {
|
||||
@@ -265,7 +164,7 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
}
|
||||
if (devs.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__);
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating a Meta device for tensor parallelism from %zu devices:\n", __func__, devs.size());
|
||||
@@ -347,8 +246,7 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
} else {
|
||||
if (params.main_gpu >= (int)model->devices.size()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size());
|
||||
llama_model_free(model);
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
llama_device main_gpu = model->devices[params.main_gpu];
|
||||
model->devices.clear();
|
||||
@@ -365,7 +263,121 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
props.memory_free/1024/1024);
|
||||
}
|
||||
|
||||
const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, *model, params);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||
static std::pair<int, llama_model *> llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud,
|
||||
const std::string & fname, std::vector<std::string> & splits, FILE * file, llama_model_params & params) {
|
||||
try {
|
||||
llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io,
|
||||
params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
|
||||
|
||||
ml.print_info();
|
||||
std::unique_ptr<llama_model> model_ptr(llama_model_create(ml, params));
|
||||
|
||||
bool ok = llama_prepare_model_devices(params, model_ptr.get());
|
||||
if (!ok) {
|
||||
return {-1, nullptr};
|
||||
}
|
||||
|
||||
auto * model = dynamic_cast<llama_model_base *>(model_ptr.get());
|
||||
if (model == nullptr) {
|
||||
GGML_ABORT("fatal error: model does not implement llama_model_base");
|
||||
}
|
||||
|
||||
// loading time will be recalculated after the first eval, so
|
||||
// we take page faults deferred by mmap() into consideration
|
||||
model->t_load_us = 0;
|
||||
time_meas tm(model->t_load_us);
|
||||
|
||||
model->t_start_us = tm.t_start_us;
|
||||
|
||||
model->hparams.vocab_only = params.vocab_only;
|
||||
model->hparams.no_alloc = params.no_alloc;
|
||||
|
||||
try {
|
||||
model->load_hparams(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
||||
}
|
||||
if (model->arch == LLM_ARCH_CLIP) {
|
||||
throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
|
||||
}
|
||||
try {
|
||||
model->load_vocab(ml);
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
model->load_stats(ml);
|
||||
model->print_info();
|
||||
|
||||
if (params.vocab_only) {
|
||||
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
|
||||
return {0, model_ptr.release()};
|
||||
}
|
||||
|
||||
if (!model->load_tensors(ml)) {
|
||||
return {-2, nullptr};
|
||||
}
|
||||
|
||||
return {0, model_ptr.release()};
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
||||
return {-1, nullptr};
|
||||
}
|
||||
}
|
||||
|
||||
static struct llama_model * llama_model_load_from_file_impl(
|
||||
struct gguf_context * metadata,
|
||||
llama_model_set_tensor_data_t set_tensor_data,
|
||||
void * set_tensor_data_ud,
|
||||
const std::string & path_model,
|
||||
std::vector<std::string> & splits,
|
||||
FILE * file,
|
||||
struct llama_model_params params) {
|
||||
{
|
||||
int n_sources_defined = 0;
|
||||
if (metadata != nullptr) {
|
||||
n_sources_defined++;
|
||||
}
|
||||
if (!path_model.empty()) {
|
||||
n_sources_defined++;
|
||||
}
|
||||
if (file != nullptr) {
|
||||
n_sources_defined++;
|
||||
}
|
||||
if (n_sources_defined != 1) {
|
||||
LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
ggml_time_init();
|
||||
|
||||
if (!params.vocab_only && ggml_backend_reg_count() == 0) {
|
||||
LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned cur_percentage = 0;
|
||||
if (params.progress_callback == NULL) {
|
||||
params.progress_callback_user_data = &cur_percentage;
|
||||
params.progress_callback = [](float progress, void * ctx) {
|
||||
unsigned * cur_percentage_p = (unsigned *) ctx;
|
||||
unsigned percentage = (unsigned) (100 * progress);
|
||||
while (percentage > *cur_percentage_p) {
|
||||
*cur_percentage_p = percentage;
|
||||
LLAMA_LOG_CONT(".");
|
||||
if (percentage >= 100) {
|
||||
LLAMA_LOG_CONT("\n");
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
}
|
||||
|
||||
const auto [status, model] = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, params);
|
||||
GGML_ASSERT(status <= 0);
|
||||
if (status < 0) {
|
||||
if (status == -1) {
|
||||
@@ -374,7 +386,9 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
|
||||
}
|
||||
|
||||
llama_model_free(model);
|
||||
if (model) {
|
||||
llama_model_free(model);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
+107
-1
@@ -1,6 +1,112 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
void llama_model_afmoe::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
|
||||
// Set up interleaved sliding window attention (ISWA)
|
||||
// Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4)
|
||||
if (hparams.n_swa > 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
uint32_t swa_period = 4;
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
|
||||
hparams.set_swa_pattern(swa_period);
|
||||
|
||||
hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
|
||||
hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
|
||||
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
}
|
||||
|
||||
// Default to sigmoid if not set
|
||||
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 56: type = LLM_TYPE_6B; break;
|
||||
case 32: type = LLM_TYPE_26B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_afmoe::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
const int64_t n_expert_shared = hparams.n_expert_shared;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
// dual attention normalization
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// attention projections
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
// Q/K normalization
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
// attention gating
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
|
||||
// dual ffn normalization
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
|
||||
// MoE layers
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
|
||||
|
||||
// grouped expert weights
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
|
||||
// shared expert
|
||||
if (n_expert_shared > 0) {
|
||||
const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
}
|
||||
} else {
|
||||
// Dense layers
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_afmoe::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
|
||||
+57
-1
@@ -1,6 +1,62 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
void llama_model_apertus::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer);
|
||||
ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_8B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_apertus::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
} else {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
|
||||
// optional bias tensors
|
||||
layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
|
||||
|
||||
// Q and K layernorms for Apertus
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_apertus::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
+46
-1
@@ -1,6 +1,51 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
void llama_model_arcee::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
// Arcee uses the same structure as Llama
|
||||
switch (hparams.n_layer) {
|
||||
case 36: type = LLM_TYPE_4B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_arcee::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_arcee::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
+54
-1
@@ -1,6 +1,59 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
void llama_model_arctic::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
if (hparams.n_expert == 128) {
|
||||
switch (hparams.n_layer) {
|
||||
case 35: type = LLM_TYPE_10B_128x3_66B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} else {
|
||||
type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_arctic::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_arctic::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
+117
-1
@@ -1,7 +1,123 @@
|
||||
#include "models.h"
|
||||
|
||||
void llama_model_arwkv7::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
|
||||
ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
|
||||
ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay);
|
||||
ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix);
|
||||
ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false);
|
||||
ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);
|
||||
|
||||
llm_build_arwkv7::llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) {
|
||||
switch (hparams.n_layer) {
|
||||
case 12:
|
||||
switch (hparams.n_embd) {
|
||||
case 768: type = LLM_TYPE_190M; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 24:
|
||||
switch (hparams.n_embd) {
|
||||
case 1024: type = LLM_TYPE_450M; break;
|
||||
case 2048: type = LLM_TYPE_1_5B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 28:
|
||||
switch (hparams.n_embd) {
|
||||
case 1536: type = LLM_TYPE_1_5B; break;
|
||||
case 3584: type = LLM_TYPE_7B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 32:
|
||||
switch (hparams.n_embd) {
|
||||
case 2560: type = LLM_TYPE_2_9B; break;
|
||||
case 4096: type = LLM_TYPE_7B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 61:
|
||||
switch (hparams.n_embd) {
|
||||
case 4096: type = LLM_TYPE_14B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_arwkv7::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
const int n_lora_decay = hparams.n_lora_decay;
|
||||
const int n_lora_iclr = hparams.n_lora_iclr;
|
||||
const int n_lora_value_res_mix = hparams.n_lora_value_res_mix;
|
||||
const int n_lora_gate = hparams.n_lora_gate;
|
||||
const int attn_hidden_size = n_embd;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0);
|
||||
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0);
|
||||
|
||||
layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
|
||||
if (i == 0) {
|
||||
// actually not used
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0);
|
||||
} else {
|
||||
layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0);
|
||||
layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
try {
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
|
||||
} catch(std::runtime_error & e) {
|
||||
// ARWKV models may not have gate tensors
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
|
||||
}
|
||||
|
||||
layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);
|
||||
layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0);
|
||||
|
||||
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
|
||||
layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_arwkv7::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) {
|
||||
GGML_ASSERT(n_embd == hparams.n_embd_r());
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
+44
-1
@@ -1,6 +1,49 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
void llama_model_baichuan::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_7B; break;
|
||||
case 40: type = LLM_TYPE_13B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
if (type == LLM_TYPE_13B) {
|
||||
// TODO: become GGUF KV parameter
|
||||
hparams.f_max_alibi_bias = 8.0f;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_baichuan::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
{
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_baichuan::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
@@ -1,6 +1,65 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
void llama_model_bailingmoe::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 28: type = LLM_TYPE_16B; break;
|
||||
case 88: type = LLM_TYPE_290B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_bailingmoe::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
const int64_t n_expert_shared = hparams.n_expert_shared;
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
|
||||
if (n_expert == 0) {
|
||||
throw std::runtime_error("n_expert must be > 0");
|
||||
}
|
||||
if (n_expert_used == 0) {
|
||||
throw std::runtime_error("n_expert_used must be > 0");
|
||||
}
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_bailingmoe::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
|
||||
@@ -1,6 +1,100 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) :
|
||||
void llama_model_bailingmoe2::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 20: type = LLM_TYPE_16B_A1B; break;
|
||||
case 21: type = LLM_TYPE_16B_A1B; break;
|
||||
case 32: type = LLM_TYPE_100B_A6B; break;
|
||||
case 33: type = LLM_TYPE_100B_A6B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
const int64_t n_expert_shared = hparams.n_expert_shared;
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2");
|
||||
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2");
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
int flags = 0;
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
// skip all tensors in the NextN layers
|
||||
flags |= TENSOR_SKIP;
|
||||
}
|
||||
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
|
||||
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
|
||||
const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared;
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
|
||||
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
|
||||
} else { // Dense layers
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
|
||||
}
|
||||
|
||||
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags);
|
||||
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_bailingmoe2::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
|
||||
+78
-1
@@ -1,6 +1,83 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
void llama_model_bert::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 3:
|
||||
type = LLM_TYPE_17M; break; // bge-micro
|
||||
case 6:
|
||||
type = LLM_TYPE_22M; break; // MiniLM-L6
|
||||
case 12:
|
||||
switch (hparams.n_embd) {
|
||||
case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small
|
||||
case 768: type = LLM_TYPE_109M; break; // bge-base
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
} break;
|
||||
case 24:
|
||||
type = LLM_TYPE_335M; break; // bge-large
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_bert::load_arch_tensors(llama_model_loader &) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
if (n_token_types == 0) {
|
||||
throw std::runtime_error(arch_name() + " model needs to define token type count");
|
||||
}
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (arch == LLM_ARCH_BERT) {
|
||||
pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
|
||||
|
||||
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
||||
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
|
||||
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0);
|
||||
tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0);
|
||||
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);
|
||||
|
||||
if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
} else {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (arch == LLM_ARCH_NOMIC_BERT) {
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_bert::build_arch_graph(const llm_graph_params & params) const {
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
llama_model_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user