forked from wylab/llama.cpp
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| af237f3026 | |||
| 1a5631beaa | |||
| 1dab5f5a44 | |||
| c96f608d98 | |||
| 0842b9b465 | |||
| 59db9a357d | |||
| 23fbfcb1ad | |||
| e22cd0aa15 | |||
| 96cfc4992c | |||
| ed0007aa32 | |||
| 344ee2a38a | |||
| d6e1556499 | |||
| f76565db92 | |||
| 43e1cbd6c1 | |||
| 107d599952 | |||
| e8bbc736cb | |||
| b518195101 | |||
| e2763a6723 | |||
| 0beb8db3a0 | |||
| b2f460bd3c | |||
| 5f4cdac385 | |||
| ae87863dc1 | |||
| 97c64fbdbd | |||
| d417bc43dd | |||
| 35bee031e1 | |||
| 451ef08432 | |||
| 9b24886f78 | |||
| 62b8143ad2 | |||
| d088d5b74f | |||
| cd18a50ea5 |
@@ -39,6 +39,7 @@ Before submitting your PR:
|
||||
- For intricate features, consider opening a feature request first to discuss and align expectations
|
||||
- When adding support for a new model or feature, focus on **CPU support only** in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs
|
||||
- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly
|
||||
- If you are a new contributor, limit your open PRs to 1.
|
||||
|
||||
After submitting your PR:
|
||||
- Expect requests for modifications to ensure the code meets llama.cpp's standards for quality and long-term maintainability
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "chat-auto-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
@@ -51,13 +52,15 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
bool has_tools =
|
||||
autoparser.tools.format.mode != tool_format::NONE && inputs.tools.is_array() && !inputs.tools.empty();
|
||||
std::string trigger_marker = !autoparser.tools.format.section_start.empty() ? autoparser.tools.format.section_start :
|
||||
autoparser.tools.format.per_call_start;
|
||||
bool include_grammar =
|
||||
has_tools && ((inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO && !trigger_marker.empty()) ||
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED);
|
||||
autoparser.tools.format.per_call_start;
|
||||
|
||||
bool has_response_format = !inputs.json_schema.empty() && inputs.json_schema.is_object();
|
||||
bool include_grammar = has_response_format || (has_tools &&
|
||||
((inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO && !trigger_marker.empty()) ||
|
||||
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar_lazy = !has_response_format && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
@@ -68,7 +71,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
});
|
||||
|
||||
// Set grammar triggers based on tool section markers (fall back to per-call markers)
|
||||
if (data.grammar_lazy) { // only do triggers on lazy grammar
|
||||
if (data.grammar_lazy) {
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, trigger_marker }
|
||||
};
|
||||
@@ -87,7 +90,7 @@ common_peg_arena autoparser::build_parser(const templates_params & inputs) const
|
||||
// pre-register a json-string rule that accepts both quote styles. This must happen
|
||||
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
|
||||
if (tools.format.uses_python_dicts) {
|
||||
p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); });
|
||||
p.rule("json-string", p.quoted_string());
|
||||
}
|
||||
|
||||
parser_build_context ctx(p, inputs);
|
||||
@@ -104,8 +107,11 @@ common_peg_arena autoparser::build_parser(const templates_params & inputs) const
|
||||
bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
|
||||
if (has_response_format) {
|
||||
return ctx.reasoning_parser + p.space() +
|
||||
p.content(p.schema(p.json(), "response-format", inputs.json_schema)) + p.end();
|
||||
auto response_format = p.rule("response-format", p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema)));
|
||||
return ctx.reasoning_parser + p.space() + p.choice({
|
||||
p.literal("```json") + p.space() + response_format + p.space() + p.literal("```"),
|
||||
response_format
|
||||
}) + p.end();
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) {
|
||||
|
||||
@@ -162,7 +162,7 @@ diff_split calculate_diff_split(const std::string & left, const std::string & ri
|
||||
right_fully_consumed = true;
|
||||
}
|
||||
|
||||
auto eat_segment = [](std::string & str, segment & seg) -> std::string { return str.append(seg.value); };
|
||||
auto eat_segment = [](std::string str, const segment & seg) -> std::string { return std::move(str) + seg.value; };
|
||||
|
||||
bool can_have_text_suffix = left_end->type == segment_type::TEXT && right_end->type == segment_type::TEXT;
|
||||
bool can_have_text_prefix = right_start->type == segment_type::TEXT && left_start->type == segment_type::TEXT;
|
||||
|
||||
+78
-11
@@ -167,8 +167,8 @@ void tag_based_peg_mapper::from_ast(const common_peg_ast_arena & arena, const co
|
||||
});
|
||||
}
|
||||
|
||||
tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, bool is_partial) const {
|
||||
common_peg_parse_context ctx(input, is_partial);
|
||||
tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, common_peg_parse_flags extra_flags) const {
|
||||
common_peg_parse_context ctx(input, flags | extra_flags);
|
||||
auto parse_result = arena.parse(ctx);
|
||||
|
||||
tag_based_peg_mapper mapper;
|
||||
@@ -179,11 +179,10 @@ tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & inp
|
||||
|
||||
tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const {
|
||||
if (input.empty()) {
|
||||
return parse_and_extract(input, false);
|
||||
return parse_and_extract(input);
|
||||
}
|
||||
for (size_t i = 0; i < input.size(); i++) {
|
||||
common_peg_parse_context ctx(input, false);
|
||||
ctx.debug = debug;
|
||||
common_peg_parse_context ctx(input, flags);
|
||||
auto parse_result = arena.parse(ctx, i);
|
||||
if (parse_result.success() || i == input.size() - 1) {
|
||||
tag_based_peg_mapper mapper;
|
||||
@@ -477,6 +476,74 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
|
||||
return force_tool_calls ? section : optional(section);
|
||||
}
|
||||
|
||||
// Python-style tool calls: name(arg1="value1", arg2=123)
|
||||
// Used only by LFM2 for now, so we don't merge it into autoparser
|
||||
common_peg_parser common_chat_peg_builder::python_style_tool_calls(
|
||||
const nlohmann::json & tools,
|
||||
bool parallel_tool_calls) {
|
||||
if (!tools.is_array() || tools.empty()) {
|
||||
return eps();
|
||||
}
|
||||
|
||||
auto tool_choices = choice();
|
||||
|
||||
for (const auto & tool_def : tools) {
|
||||
if (!tool_def.contains("function")) {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_def.at("function");
|
||||
std::string name = function.at("name");
|
||||
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
|
||||
|
||||
auto args = eps();
|
||||
if (params.contains("properties") && !params["properties"].empty()) {
|
||||
auto arg_choice = choice();
|
||||
for (const auto & el : params["properties"].items()) {
|
||||
const std::string & prop_name = el.key();
|
||||
const auto & prop_def = el.value();
|
||||
bool is_string_type = (prop_def.contains("type") && prop_def["type"] == "string");
|
||||
|
||||
auto arg_name_parser = literal(prop_name);
|
||||
|
||||
common_peg_parser arg_value_parser = eps();
|
||||
auto string_value_parser = choice({
|
||||
literal("\"") + tool_arg_string_value(string_content('"')) + literal("\""),
|
||||
literal("'") + tool_arg_string_value(string_content('\'')) + literal("'")
|
||||
});
|
||||
|
||||
if (is_string_type) {
|
||||
arg_value_parser = string_value_parser;
|
||||
} else {
|
||||
arg_value_parser = tool_arg_value(python_value());
|
||||
}
|
||||
|
||||
// Full argument: name="value" or name=value
|
||||
auto arg_rule = tool_arg(
|
||||
tool_arg_open(eps()) +
|
||||
tool_arg_name(arg_name_parser) +
|
||||
literal("=") +
|
||||
arg_value_parser +
|
||||
tool_arg_close(eps())
|
||||
);
|
||||
arg_choice |= arg_rule;
|
||||
}
|
||||
|
||||
args = arg_choice + zero_or_more("," + space() + arg_choice);
|
||||
}
|
||||
|
||||
auto tool_parser = tool(tool_open(tool_name(literal(name)) + literal("(")) +
|
||||
space() + tool_args(args) + space() + tool_close(literal(")"))
|
||||
);
|
||||
|
||||
tool_choices |= rule("tool-" + name, tool_parser);
|
||||
}
|
||||
|
||||
if (parallel_tool_calls) {
|
||||
return "[" + space() + tool_choices + zero_or_more("," + space() + tool_choices) + space() + "]";
|
||||
}
|
||||
return "[" + space() + tool_choices + space() + "]";
|
||||
}
|
||||
|
||||
// Helper: Parse dot notation key into prefix and field name
|
||||
static std::pair<std::string, std::string> parse_key_spec(const std::string & key) {
|
||||
auto dot_pos = key.find('.');
|
||||
@@ -510,7 +577,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
||||
if (!call_id_key.empty()) {
|
||||
auto id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\"")
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\"")
|
||||
);
|
||||
inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space())));
|
||||
}
|
||||
@@ -519,7 +586,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
||||
auto gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
@@ -608,7 +675,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
||||
if (id_spec.first.empty()) {
|
||||
auto id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\"")
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\"")
|
||||
);
|
||||
tool_parser_body = tool_parser_body + optional(id_parser + space() + literal(",") + space());
|
||||
}
|
||||
@@ -620,7 +687,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
||||
auto gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
@@ -669,7 +736,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
||||
id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
@@ -680,7 +747,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
||||
gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
|
||||
@@ -112,6 +112,11 @@ class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
bool parallel_tool_calls,
|
||||
bool force_tool_calls);
|
||||
|
||||
// Helper for Python-style function call format: name(arg1="value1", arg2=123)
|
||||
// Used by LFM2 and similar templates
|
||||
common_peg_parser python_style_tool_calls(const nlohmann::json & tools,
|
||||
bool parallel_tool_calls);
|
||||
|
||||
private:
|
||||
// Implementation helpers for standard_json_tools — one per JSON tool call layout mode
|
||||
common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools,
|
||||
@@ -155,19 +160,19 @@ struct tagged_parse_result {
|
||||
|
||||
struct tagged_peg_parser {
|
||||
common_peg_arena arena;
|
||||
bool debug = false;
|
||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE;
|
||||
|
||||
tagged_peg_parser & withDebug() {
|
||||
debug = true;
|
||||
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
return *this;
|
||||
}
|
||||
|
||||
tagged_peg_parser & withoutDebug() {
|
||||
debug = false;
|
||||
flags = flags & ~COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
return *this;
|
||||
}
|
||||
|
||||
tagged_parse_result parse_and_extract(const std::string & input, bool is_partial = false) const;
|
||||
tagged_parse_result parse_and_extract(const std::string & input, common_peg_parse_flags extra_flags = COMMON_PEG_PARSE_FLAG_NONE) const;
|
||||
tagged_parse_result parse_anywhere_and_extract(const std::string & input) const;
|
||||
};
|
||||
|
||||
|
||||
+107
-4
@@ -1274,8 +1274,95 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||
return data;
|
||||
}
|
||||
|
||||
// LFM2 format:
|
||||
// - Reasoning: <think>{reasoning}</think> (optional, only if enable_thinking is true)
|
||||
// - Content: text after reasoning (optional)
|
||||
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
|
||||
// Tool calls can appear multiple times (parallel tool calls)
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
|
||||
const autoparser::templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<|tool_list_start|>",
|
||||
"<|tool_list_end|>",
|
||||
"<|tool_call_start|>",
|
||||
"<|tool_call_end|>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
|
||||
const std::string TOOL_CALL_START = "<|tool_call_start|>";
|
||||
const std::string TOOL_CALL_END = "<|tool_call_end|>";
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning && inputs.enable_thinking) {
|
||||
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
p.trigger_rule("tool-call", p.literal(TOOL_CALL_START) +
|
||||
p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) +
|
||||
p.literal(TOOL_CALL_END)
|
||||
)
|
||||
);
|
||||
|
||||
auto content = p.content(p.until(TOOL_CALL_START));
|
||||
|
||||
return reasoning + content + tool_calls + end;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, TOOL_CALL_START }
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
namespace workaround {
|
||||
|
||||
static void map_developer_role_to_system(json & messages) {
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("role")) {
|
||||
if (message["role"] == "developer") {
|
||||
message["role"] = "system";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// if first message is system and template does not support it, merge it with next message
|
||||
static void system_message_not_supported(json & messages) {
|
||||
if (!messages.empty() && messages.front().at("role") == "system") {
|
||||
@@ -1353,6 +1440,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
params.add_bos = tmpls->add_bos;
|
||||
params.add_eos = tmpls->add_eos;
|
||||
|
||||
if (src.find("<|channel|>") == std::string::npos) {
|
||||
// map developer to system for all models except for GPT-OSS
|
||||
workaround::map_developer_role_to_system(params.messages);
|
||||
}
|
||||
workaround::func_args_not_string(params.messages);
|
||||
|
||||
if (!tmpl.original_caps().supports_system_role) {
|
||||
@@ -1422,6 +1513,14 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
return common_chat_params_init_kimi_k2(tmpl, params);
|
||||
}
|
||||
|
||||
// LFM2 - uses <|tool_list_start|>/<|tool_list_end|> markers and <|tool_call_start|>[name(args)]<|tool_call_end|> format
|
||||
// Detection: template has "<|tool_list_start|>" and "<|tool_list_end|>" markers
|
||||
if (src.find("<|tool_list_start|>") != std::string::npos &&
|
||||
src.find("<|tool_list_end|>") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: LFM2\n");
|
||||
return common_chat_params_init_lfm2(tmpl, params);
|
||||
}
|
||||
|
||||
try {
|
||||
LOG_DBG("Using differential autoparser\n");
|
||||
struct autoparser::autoparser autoparser;
|
||||
@@ -1527,8 +1626,12 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
|
||||
common_peg_parse_context ctx(input, is_partial);
|
||||
ctx.debug = params.debug;
|
||||
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
|
||||
if (params.debug) {
|
||||
flags |= COMMON_PEG_PARSE_FLAG_DEBUG;
|
||||
}
|
||||
|
||||
common_peg_parse_context ctx(input, flags);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
if (result.fail()) {
|
||||
@@ -1541,7 +1644,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
|
||||
fflush(stderr);
|
||||
}
|
||||
@@ -1557,7 +1660,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str());
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
+16
-1
@@ -7,6 +7,7 @@ struct common_http_url {
|
||||
std::string user;
|
||||
std::string password;
|
||||
std::string host;
|
||||
int port;
|
||||
std::string path;
|
||||
};
|
||||
|
||||
@@ -47,6 +48,20 @@ static common_http_url common_http_parse_url(const std::string & url) {
|
||||
parts.host = rest;
|
||||
parts.path = "/";
|
||||
}
|
||||
|
||||
auto colon_pos = parts.host.find(':');
|
||||
|
||||
if (colon_pos != std::string::npos) {
|
||||
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
|
||||
parts.host = parts.host.substr(0, colon_pos);
|
||||
} else if (parts.scheme == "http") {
|
||||
parts.port = 80;
|
||||
} else if (parts.scheme == "https") {
|
||||
parts.port = 443;
|
||||
} else {
|
||||
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
|
||||
}
|
||||
|
||||
return parts;
|
||||
}
|
||||
|
||||
@@ -68,7 +83,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
}
|
||||
#endif
|
||||
|
||||
httplib::Client cli(parts.scheme + "://" + parts.host);
|
||||
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
|
||||
|
||||
if (!parts.user.empty()) {
|
||||
cli.set_basic_auth(parts.user, parts.password);
|
||||
|
||||
+151
-168
@@ -349,7 +349,7 @@ struct parser_executor {
|
||||
auto pos = start_pos;
|
||||
for (auto i = 0u; i < p.literal.size(); ++i) {
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -364,7 +364,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_sequence_parser & p) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
LOG_DBG("%sSEQ start at %zu '%s' (%zu children)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.children.size());
|
||||
}
|
||||
@@ -375,26 +375,19 @@ struct parser_executor {
|
||||
|
||||
for (size_t i = 0; i < p.children.size(); i++) {
|
||||
const auto & child_id = p.children[i];
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ child %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
|
||||
}
|
||||
auto result = arena.parse(child_id, ctx, pos);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ child %zu: %s at %zu->%zu\n", debug_indent().c_str(), i,
|
||||
common_peg_parse_result_type_name(result.type), result.start, result.end);
|
||||
}
|
||||
|
||||
if (result.fail()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.is_partial && result.end >= ctx.input.size()) {
|
||||
if (ctx.debug) {
|
||||
fprintf(stderr, "%sSEQ -> NEED_MORE (child failed at end)\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end,
|
||||
std::move(nodes));
|
||||
}
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ -> FAIL\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end);
|
||||
@@ -406,7 +399,7 @@ struct parser_executor {
|
||||
|
||||
if (result.need_more_input()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ -> NEED_MORE\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes));
|
||||
@@ -416,14 +409,14 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sSEQ -> SUCCESS at %zu->%zu\n", debug_indent().c_str(), start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes));
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_choice_parser & p) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE start at %zu '%s' (%zu options)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.children.size());
|
||||
}
|
||||
@@ -432,17 +425,17 @@ struct parser_executor {
|
||||
auto pos = start_pos;
|
||||
for (size_t i = 0; i < p.children.size(); i++) {
|
||||
const auto & child_id = p.children[i];
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str());
|
||||
}
|
||||
auto result = arena.parse(child_id, ctx, pos);
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i,
|
||||
common_peg_parse_result_type_name(result.type));
|
||||
}
|
||||
if (!result.fail()) {
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE -> %s (option %zu)\n", debug_indent().c_str(),
|
||||
common_peg_parse_result_type_name(result.type), i);
|
||||
}
|
||||
@@ -451,14 +444,14 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sCHOICE -> FAIL (no options matched)\n", debug_indent().c_str());
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_repetition_parser & p) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT start at %zu '%s' (min=%d, max=%d)\n", debug_indent().c_str(), start_pos,
|
||||
debug_input_snippet(start_pos).c_str(), p.min_count, p.max_count);
|
||||
}
|
||||
@@ -471,7 +464,7 @@ struct parser_executor {
|
||||
// Try to match up to max_count times (or unlimited if max_count is -1)
|
||||
while (p.max_count == -1 || match_count < p.max_count) {
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT: at end of input, count=%d\n", debug_indent().c_str(), match_count);
|
||||
}
|
||||
break;
|
||||
@@ -479,7 +472,7 @@ struct parser_executor {
|
||||
|
||||
auto result = arena.parse(p.child, ctx, pos);
|
||||
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT iter %d: %s at %zu->%zu, nodes=%zu\n", debug_indent().c_str(), match_count,
|
||||
common_peg_parse_result_type_name(result.type), result.start, result.end, result.nodes.size());
|
||||
fprintf(stderr, "%sREPEAT CHILD: %s\n", debug_indent().c_str(), arena.dump(p.child).c_str());
|
||||
@@ -488,7 +481,7 @@ struct parser_executor {
|
||||
if (result.success()) {
|
||||
// Prevent infinite loop on empty matches
|
||||
if (result.end == pos) {
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%s REPEAT: empty match, stopping\n", debug_indent().c_str());
|
||||
}
|
||||
break;
|
||||
@@ -509,7 +502,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> NEED_MORE (count=%d, nodes=%zu)\n", debug_indent().c_str(),
|
||||
match_count, nodes.size());
|
||||
}
|
||||
@@ -517,7 +510,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
// Child failed - stop trying
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT: child failed, stopping\n", debug_indent().c_str());
|
||||
}
|
||||
break;
|
||||
@@ -526,14 +519,14 @@ struct parser_executor {
|
||||
// Check if we got enough matches
|
||||
if (p.min_count > 0 && match_count < p.min_count) {
|
||||
ctx.parse_depth--;
|
||||
if (pos >= ctx.input.size() && ctx.is_partial) {
|
||||
if (ctx.debug) {
|
||||
if (pos >= ctx.input.size() && ctx.is_lenient()) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> NEED_MORE (not enough matches: %d < %d)\n", debug_indent().c_str(),
|
||||
match_count, p.min_count);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes));
|
||||
}
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> FAIL (not enough matches: %d < %d)\n", debug_indent().c_str(), match_count,
|
||||
p.min_count);
|
||||
}
|
||||
@@ -541,7 +534,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
ctx.parse_depth--;
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sREPEAT -> SUCCESS (count=%d, nodes=%zu)\n", debug_indent().c_str(), match_count,
|
||||
nodes.size());
|
||||
}
|
||||
@@ -576,7 +569,7 @@ struct parser_executor {
|
||||
auto result = common_parse_utf8_codepoint(ctx.input, start_pos);
|
||||
|
||||
if (result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
|
||||
@@ -615,7 +608,7 @@ struct parser_executor {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
// Not enough matches yet
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -656,7 +649,7 @@ struct parser_executor {
|
||||
|
||||
// Check if we got enough matches
|
||||
if (match_count < p.min_count) {
|
||||
if (pos >= ctx.input.size() && ctx.is_partial) {
|
||||
if (pos >= ctx.input.size() && ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
@@ -665,32 +658,23 @@ struct parser_executor {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
|
||||
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) {
|
||||
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos, const char delimiter) {
|
||||
++pos; // consume '\'
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
|
||||
}
|
||||
|
||||
switch (ctx.input[pos]) {
|
||||
case '"':
|
||||
case '\'':
|
||||
case '\\':
|
||||
case '/':
|
||||
case 'b':
|
||||
case 'f':
|
||||
case 'n':
|
||||
case 'r':
|
||||
case 't':
|
||||
++pos;
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
|
||||
case 'u':
|
||||
return handle_unicode_escape(ctx, start, pos);
|
||||
default:
|
||||
// Invalid escape sequence
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
char c = ctx.input[pos];
|
||||
if (c == delimiter || c == '\\' || c == '/' || c == 'b' || c == 'f' || c == 'n' || c == 'r' || c == 't') {
|
||||
++pos;
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
|
||||
} else if (c == 'u') {
|
||||
return handle_unicode_escape(ctx, start, pos);
|
||||
} else {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -698,7 +682,7 @@ struct parser_executor {
|
||||
++pos; // consume 'u'
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
|
||||
@@ -711,20 +695,20 @@ struct parser_executor {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_json_string_parser & /* p */) {
|
||||
common_peg_parse_result operator()(const common_peg_string_parser & p) {
|
||||
auto pos = start_pos;
|
||||
|
||||
// Parse string content (without quotes)
|
||||
while (pos < ctx.input.size()) {
|
||||
char c = ctx.input[pos];
|
||||
|
||||
if (c == '"') {
|
||||
// Found closing quote - success (don't consume it)
|
||||
if (c == p.delimiter) {
|
||||
// Found closing delimiter - success (don't consume it)
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
|
||||
if (c == '\\') {
|
||||
auto result = handle_escape_sequence(ctx, start_pos, pos);
|
||||
auto result = handle_escape_sequence(ctx, start_pos, pos, p.delimiter);
|
||||
if (!result.success()) {
|
||||
return result;
|
||||
}
|
||||
@@ -732,7 +716,7 @@ struct parser_executor {
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -747,49 +731,7 @@ struct parser_executor {
|
||||
}
|
||||
|
||||
// Reached end without finding closing quote
|
||||
if (!ctx.is_partial) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) {
|
||||
auto pos = start_pos;
|
||||
|
||||
// Parse string content (without quotes)
|
||||
while (pos < ctx.input.size()) {
|
||||
char c = ctx.input[pos];
|
||||
|
||||
if (c == '\'') {
|
||||
// Found closing quote - success (don't consume it)
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
|
||||
if (c == '\\') {
|
||||
auto result = handle_escape_sequence(ctx, start_pos, pos);
|
||||
if (!result.success()) {
|
||||
return result;
|
||||
}
|
||||
} else {
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_partial) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INVALID) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
|
||||
pos += utf8_result.bytes_consumed;
|
||||
}
|
||||
}
|
||||
|
||||
// Reached end without finding closing quote
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
@@ -807,7 +749,7 @@ struct parser_executor {
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
// Incomplete UTF-8 sequence
|
||||
if (!ctx.is_partial) {
|
||||
if (!ctx.is_lenient()) {
|
||||
// Input is complete but UTF-8 is incomplete = malformed
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
@@ -837,7 +779,7 @@ struct parser_executor {
|
||||
last_valid_pos = pos;
|
||||
}
|
||||
|
||||
if (last_valid_pos == ctx.input.size() && ctx.is_partial) {
|
||||
if (last_valid_pos == ctx.input.size() && ctx.is_lenient()) {
|
||||
// Reached the end of a partial stream, there might still be more input that we need to consume.
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, last_valid_pos);
|
||||
}
|
||||
@@ -876,7 +818,7 @@ struct parser_executor {
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_tag_parser & p) {
|
||||
// Parse the child
|
||||
if (ctx.debug) {
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "%sTAG: %s\n", debug_indent().c_str(), p.tag.c_str());
|
||||
}
|
||||
auto result = arena.parse(p.child, ctx, start_pos);
|
||||
@@ -995,8 +937,7 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_ref_parser> ||
|
||||
std::is_same_v<T, common_peg_until_parser> ||
|
||||
std::is_same_v<T, common_peg_literal_parser> ||
|
||||
std::is_same_v<T, common_peg_json_string_parser> ||
|
||||
std::is_same_v<T, common_peg_python_dict_string_parser> ||
|
||||
std::is_same_v<T, common_peg_string_parser> ||
|
||||
std::is_same_v<T, common_peg_chars_parser> ||
|
||||
std::is_same_v<T, common_peg_any_parser> ||
|
||||
std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1072,10 +1013,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", unbounded)";
|
||||
}
|
||||
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return "JsonString()";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return "PythonDictString()";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
|
||||
return "String(" + std::string(1, p.delimiter) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
return "Until(" + string_join(p.delimiters, " | ") + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
@@ -1288,47 +1227,25 @@ common_peg_arena common_peg_parser_builder::build() {
|
||||
|
||||
// String primitives
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_string_content() {
|
||||
return wrap(arena_.add_parser(common_peg_json_string_parser{}));
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string_content() {
|
||||
return wrap(arena_.add_parser(common_peg_python_dict_string_parser{}));
|
||||
common_peg_parser common_peg_parser_builder::string_content(char delimiter) {
|
||||
return wrap(arena_.add_parser(common_peg_string_parser{delimiter}));
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::double_quoted_string() {
|
||||
return rule("dq-string",
|
||||
[this]() { return sequence({ literal("\""), json_string_content(), literal("\""), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string() {
|
||||
return rule("sq-string",
|
||||
[this]() { return sequence({ literal("'"), single_quoted_string_content(), literal("'"), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::flexible_string() {
|
||||
return rule("flexible-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
|
||||
}
|
||||
|
||||
// Generic helpers for object/array structure
|
||||
|
||||
common_peg_parser common_peg_parser_builder::generic_object(const std::string & name,
|
||||
const common_peg_parser & string_parser,
|
||||
const common_peg_parser & value_parser) {
|
||||
return rule(name, [this, string_parser, value_parser]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser });
|
||||
auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) });
|
||||
return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) });
|
||||
return rule("double-quoted-string", [this]() {
|
||||
return sequence({literal("\""), string_content('"'), literal("\""), space()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::generic_array(const std::string & name,
|
||||
const common_peg_parser & value_parser) {
|
||||
return rule(name, [this, value_parser]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) });
|
||||
return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) });
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string() {
|
||||
return rule("single-quoted-string", [this]() {
|
||||
return sequence({literal("'"), string_content('\''), literal("'"), space()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::quoted_string() {
|
||||
return rule("quoted-string", [this]() {
|
||||
return choice({double_quoted_string(), single_quoted_string()});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1351,7 +1268,7 @@ common_peg_parser common_peg_parser_builder::json_number() {
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_string() {
|
||||
return rule("json-string", [this]() {
|
||||
return sequence({literal("\""), json_string_content(), literal("\""), space()});
|
||||
return sequence({literal("\""), string_content('"'), literal("\""), space()});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1368,11 +1285,36 @@ common_peg_parser common_peg_parser_builder::json_null() {
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_object() {
|
||||
return generic_object("json-object", json_string(), json());
|
||||
return rule("json-object", [this]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({json_string(), ws, literal(":"), ws, json()});
|
||||
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
|
||||
return sequence({
|
||||
literal("{"),
|
||||
ws,
|
||||
choice({
|
||||
literal("}"),
|
||||
sequence({members, ws, literal("}")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_array() {
|
||||
return generic_array("json-array", json());
|
||||
return rule("json-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
choice({
|
||||
literal("]"),
|
||||
sequence({elements, ws, literal("]")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json() {
|
||||
@@ -1389,7 +1331,9 @@ common_peg_parser common_peg_parser_builder::json() {
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_string() {
|
||||
return rule("python-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
|
||||
return rule("python-string", [this]() {
|
||||
return choice({double_quoted_string(), single_quoted_string()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_number() {
|
||||
@@ -1397,24 +1341,63 @@ common_peg_parser common_peg_parser_builder::python_number() {
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_bool() {
|
||||
return rule("python-bool", [this]() { return sequence({ choice({ literal("True"), literal("False") }), space() }); });
|
||||
return rule("python-bool", [this]() {
|
||||
return sequence({
|
||||
choice({literal("True"), literal("False")}),
|
||||
space()
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_null() {
|
||||
return rule("python-none", [this]() { return sequence({ literal("None"), space() }); });
|
||||
return rule("python-none", [this]() {
|
||||
return sequence({literal("None"), space()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_dict() {
|
||||
return generic_object("python-dict", python_string(), python_value());
|
||||
return rule("python-dict", [this]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({python_string(), ws, literal(":"), ws, python_value()});
|
||||
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
|
||||
return sequence({
|
||||
literal("{"),
|
||||
ws,
|
||||
choice({
|
||||
literal("}"),
|
||||
sequence({members, ws, literal("}")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_array() {
|
||||
return generic_array("python-array", python_value());
|
||||
return rule("python-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({python_value(), zero_or_more(sequence({literal(","), ws, python_value()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
choice({
|
||||
literal("]"),
|
||||
sequence({elements, ws, literal("]")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_value() {
|
||||
return rule("python-value", [this]() {
|
||||
return choice({ python_dict(), python_array(), python_string(), python_number(), python_bool(), python_null() });
|
||||
return choice({
|
||||
python_dict(),
|
||||
python_array(),
|
||||
python_string(),
|
||||
python_number(),
|
||||
python_bool(),
|
||||
python_null()
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1535,8 +1518,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_chars_parser> ||
|
||||
std::is_same_v<T, common_peg_space_parser> ||
|
||||
std::is_same_v<T, common_peg_any_parser> ||
|
||||
std::is_same_v<T, common_peg_json_string_parser> ||
|
||||
std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
std::is_same_v<T, common_peg_string_parser>) {
|
||||
// These parsers do not have any children
|
||||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
for (auto child : p.children) {
|
||||
@@ -1672,10 +1654,9 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return result + "{" + std::to_string(p.min_count) + "}";
|
||||
}
|
||||
return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
|
||||
const std::string delim(1, p.delimiter);
|
||||
return R"(( [^)" + delim + R"(\\] | "\\" ( [)" + delim + R"(\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
@@ -1805,10 +1786,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
{"min_count", p.min_count},
|
||||
{"max_count", p.max_count}
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return json{{"type", "json_string"}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return json{{ "type", "python_dict_string" }};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
|
||||
return json{{"type", "string"}, {"delimiter", std::string(1, p.delimiter)}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
return json{{"type", "until"}, {"delimiters", p.delimiters}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
@@ -1935,11 +1914,15 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
if (type == "json_string") {
|
||||
return common_peg_json_string_parser{};
|
||||
}
|
||||
if (type == "python_dict_string") {
|
||||
return common_peg_python_dict_string_parser{};
|
||||
if (type == "string") {
|
||||
if (!j.contains("delimiter")) {
|
||||
throw std::runtime_error("string parser missing delimiter field.");
|
||||
}
|
||||
std::string delimiter = j["delimiter"];
|
||||
if (delimiter.empty()) {
|
||||
throw std::runtime_error("string parser delimiter is empty.");
|
||||
}
|
||||
return common_peg_string_parser{delimiter[0]};
|
||||
}
|
||||
if (type == "until") {
|
||||
if (!j.contains("delimiters") || !j["delimiters"].is_array()) {
|
||||
|
||||
+36
-22
@@ -139,22 +139,43 @@ struct common_peg_parse_result {
|
||||
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
|
||||
};
|
||||
|
||||
enum common_peg_parse_flags {
|
||||
COMMON_PEG_PARSE_FLAG_NONE = 0,
|
||||
COMMON_PEG_PARSE_FLAG_LENIENT = 1 << 0,
|
||||
COMMON_PEG_PARSE_FLAG_DEBUG = 1 << 1,
|
||||
};
|
||||
|
||||
inline common_peg_parse_flags operator|(common_peg_parse_flags a, common_peg_parse_flags b) {
|
||||
return static_cast<common_peg_parse_flags>(int(a) | int(b));
|
||||
}
|
||||
|
||||
inline common_peg_parse_flags & operator|=(common_peg_parse_flags & a, common_peg_parse_flags b) {
|
||||
return a = a | b;
|
||||
}
|
||||
|
||||
inline common_peg_parse_flags operator&(common_peg_parse_flags a, common_peg_parse_flags b) {
|
||||
return static_cast<common_peg_parse_flags>(int(a) & int(b));
|
||||
}
|
||||
|
||||
inline common_peg_parse_flags operator~(common_peg_parse_flags a) {
|
||||
return static_cast<common_peg_parse_flags>(~int(a));
|
||||
}
|
||||
|
||||
struct common_peg_parse_context {
|
||||
std::string input;
|
||||
bool is_partial;
|
||||
bool debug = false; // Enable debug output for parser tracing
|
||||
common_peg_parse_flags flags;
|
||||
common_peg_ast_arena ast;
|
||||
|
||||
int parse_depth;
|
||||
|
||||
common_peg_parse_context()
|
||||
: is_partial(false), parse_depth(0) {}
|
||||
common_peg_parse_context(common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE)
|
||||
: flags(flags), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input)
|
||||
: input(input), is_partial(false), parse_depth(0) {}
|
||||
common_peg_parse_context(const std::string & input, common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_NONE)
|
||||
: input(input), flags(flags), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input, bool is_partial)
|
||||
: input(input), is_partial(is_partial), parse_depth(0) {}
|
||||
bool is_lenient() const { return flags & COMMON_PEG_PARSE_FLAG_LENIENT; }
|
||||
bool is_debug() const { return flags & COMMON_PEG_PARSE_FLAG_DEBUG; }
|
||||
};
|
||||
|
||||
class common_peg_arena;
|
||||
@@ -210,8 +231,9 @@ struct common_peg_chars_parser {
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_json_string_parser {};
|
||||
struct common_peg_python_dict_string_parser {};
|
||||
struct common_peg_string_parser {
|
||||
char delimiter;
|
||||
};
|
||||
|
||||
struct common_peg_until_parser {
|
||||
std::vector<std::string> delimiters;
|
||||
@@ -259,8 +281,7 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_any_parser,
|
||||
common_peg_space_parser,
|
||||
common_peg_chars_parser,
|
||||
common_peg_json_string_parser,
|
||||
common_peg_python_dict_string_parser,
|
||||
common_peg_string_parser,
|
||||
common_peg_until_parser,
|
||||
common_peg_schema_parser,
|
||||
common_peg_rule_parser,
|
||||
@@ -319,10 +340,6 @@ class common_peg_parser_builder {
|
||||
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
||||
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
||||
|
||||
// Generic helpers for building object/array structures with configurable string/value parsers.
|
||||
common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser);
|
||||
common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser);
|
||||
|
||||
public:
|
||||
common_peg_parser_builder();
|
||||
|
||||
@@ -423,13 +440,10 @@ class common_peg_parser_builder {
|
||||
common_peg_parser single_quoted_string();
|
||||
|
||||
// Matches a string that accepts both double-quoted and single-quoted styles.
|
||||
common_peg_parser flexible_string();
|
||||
common_peg_parser quoted_string();
|
||||
|
||||
// Matches double-quoted string content without the surrounding quotes.
|
||||
common_peg_parser json_string_content();
|
||||
|
||||
// Matches single-quoted string content without the surrounding quotes.
|
||||
common_peg_parser single_quoted_string_content();
|
||||
// Matches string content without the surrounding delimiter.
|
||||
common_peg_parser string_content(char delimiter);
|
||||
|
||||
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
||||
// value -> object | array | string | number | true | false | null
|
||||
|
||||
+9
-8
@@ -37,16 +37,17 @@ Legend:
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
@@ -54,7 +55,7 @@ Legend:
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
@@ -90,9 +91,9 @@ Legend:
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
@@ -100,7 +101,7 @@ Legend:
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
@@ -116,5 +117,5 @@ Legend:
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
||||
+2016
-7151
File diff suppressed because it is too large
Load Diff
@@ -202,8 +202,9 @@
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+1202
-3
File diff suppressed because it is too large
Load Diff
@@ -28,13 +28,17 @@ template <int K, int N> struct block {
|
||||
// control size
|
||||
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
|
||||
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
|
||||
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding");
|
||||
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
|
||||
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
|
||||
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding");
|
||||
|
||||
using block_q4_0x4 = block<4, 4>;
|
||||
using block_q4_0x8 = block<4, 8>;
|
||||
using block_q4_0x16 = block<4, 16>;
|
||||
using block_q8_0x4 = block<8, 4>;
|
||||
using block_q8_0x8 = block<8, 8>;
|
||||
using block_q8_0x16 = block<8, 16>;
|
||||
|
||||
struct block_q4_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
@@ -44,7 +48,14 @@ struct block_q4_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
struct block_q4_Kx16 {
|
||||
ggml_half d[16]; // super-block scale for quantized scales
|
||||
ggml_half dmin[16]; // super-block scale for quantized mins
|
||||
uint8_t scales[192]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[2048]; // 4--bit quants
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding");
|
||||
struct block_q2_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
@@ -53,6 +64,13 @@ struct block_q2_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||
struct block_q2_Kx16 {
|
||||
ggml_half d[16]; // Super-block scale for quantized scales
|
||||
ggml_half dmin[16]; // Super-block scale for quantized mins
|
||||
uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks)
|
||||
uint8_t qs[1024]; // Data (16 cols * 64 bytes per block)
|
||||
};
|
||||
static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding");
|
||||
|
||||
struct block_q5_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
@@ -97,6 +115,12 @@ struct block_iq4_nlx8 {
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
|
||||
|
||||
struct block_iq4_nlx16 {
|
||||
ggml_half d[16]; // deltas for 16 iq4_nl blocks
|
||||
uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding");
|
||||
struct block_mxfp4x4 {
|
||||
uint8_t e[4];
|
||||
uint8_t qs[QK_MXFP4 * 2];
|
||||
@@ -109,7 +133,6 @@ struct block_mxfp4x8 {
|
||||
};
|
||||
static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
|
||||
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -132,6 +155,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -146,10 +171,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif
|
||||
|
||||
// Native implementations
|
||||
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
@@ -170,6 +207,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -184,10 +223,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif
|
||||
|
||||
#if defined(__cplusplus)
|
||||
} // extern "C"
|
||||
|
||||
@@ -205,7 +205,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
|
||||
|
||||
int64_t total_vram = 0;
|
||||
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
||||
total_vram += prop.totalGlobalMem;
|
||||
}
|
||||
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n",
|
||||
__func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));
|
||||
total_vram = 0;
|
||||
|
||||
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
@@ -243,6 +250,12 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
#else
|
||||
info.devices[id].supports_cooperative_launch = false;
|
||||
#endif // !(GGML_USE_MUSA)
|
||||
|
||||
// cudaMemGetInfo returns info for the current device
|
||||
size_t free_mem;
|
||||
CUDA_CHECK(cudaSetDevice(id));
|
||||
CUDA_CHECK(cudaMemGetInfo(&free_mem, NULL));
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
||||
|
||||
@@ -257,22 +270,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
info.devices[id].cc += prop.minor * 0x10;
|
||||
}
|
||||
}
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB (%zu MiB free)\n",
|
||||
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
|
||||
device_vmm ? "yes" : "no", prop.warpSize);
|
||||
device_vmm ? "yes" : "no", prop.warpSize,
|
||||
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
|
||||
info.devices[id].warp_size = 32;
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
|
||||
info.devices[id].cc += prop.minor * 0x10;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
|
||||
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
|
||||
#else
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
|
||||
(size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024));
|
||||
std::string device_name(prop.name);
|
||||
if (device_name == "NVIDIA GeForce MX450") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
@@ -4976,9 +4992,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
//TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
|
||||
#ifdef GGML_USE_MUSA
|
||||
return false;
|
||||
#else
|
||||
return true;
|
||||
#endif // GGML_USE_MUSA
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
|
||||
@@ -75,6 +75,10 @@ struct ggml_metal {
|
||||
// abort ggml_metal_graph_compute if callback returns true
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_data;
|
||||
|
||||
// error state - set when a command buffer fails during synchronize
|
||||
// once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated
|
||||
bool has_error;
|
||||
};
|
||||
|
||||
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||
@@ -158,6 +162,8 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||
res->capture_started = false;
|
||||
res->capture_scope = nil;
|
||||
|
||||
res->has_error = false;
|
||||
|
||||
res->gf = nil;
|
||||
res->encode_async = nil;
|
||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||
@@ -246,7 +252,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
ctx->has_error = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,7 +269,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
// release this and all remaining command buffers before returning
|
||||
for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {
|
||||
[ctx->cmd_bufs_ext[j] release];
|
||||
}
|
||||
[ctx->cmd_bufs_ext removeAllObjects];
|
||||
|
||||
ctx->has_error = true;
|
||||
return;
|
||||
}
|
||||
|
||||
[cmd_buf release];
|
||||
@@ -414,6 +429,11 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
|
||||
}
|
||||
|
||||
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
|
||||
if (ctx->has_error) {
|
||||
GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__);
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
|
||||
// number of nodes encoded by the main thread (empirically determined)
|
||||
const int n_main = MAX(64, 0.1*gf->n_nodes);
|
||||
|
||||
|
||||
@@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
||||
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
|
||||
|
||||
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
|
||||
if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
|
||||
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
|
||||
snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
|
||||
} else {
|
||||
snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
|
||||
}
|
||||
snprintf(name, 256, "%s_aa=%d", base, antialias);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
@@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_POOL_1D:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_POOL_2D:
|
||||
|
||||
@@ -83,6 +83,7 @@
|
||||
#define FC_UNARY 1200
|
||||
#define FC_BIN 1300
|
||||
#define FC_SUM_ROWS 1400
|
||||
#define FC_UPSCALE 1500
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
@@ -890,6 +891,7 @@ typedef struct {
|
||||
float sf1;
|
||||
float sf2;
|
||||
float sf3;
|
||||
float poffs;
|
||||
} ggml_metal_kargs_upscale;
|
||||
|
||||
typedef struct {
|
||||
|
||||
@@ -1963,6 +1963,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
(
|
||||
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
||||
op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_BF16 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
||||
@@ -1977,6 +1978,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
op->src[0]->type == GGML_TYPE_Q4_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q6_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q2_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q3_K ||
|
||||
false) && (ne11 >= 4 && ne11 <= 8)
|
||||
)
|
||||
)
|
||||
@@ -3729,32 +3732,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float sf0 = (float)ne0/op->src[0]->ne[0];
|
||||
const float sf1 = (float)ne1/op->src[0]->ne[1];
|
||||
const float sf2 = (float)ne2/op->src[0]->ne[2];
|
||||
const float sf3 = (float)ne3/op->src[0]->ne[3];
|
||||
float sf0 = (float)ne0/op->src[0]->ne[0];
|
||||
float sf1 = (float)ne1/op->src[0]->ne[1];
|
||||
float sf2 = (float)ne2/op->src[0]->ne[2];
|
||||
float sf3 = (float)ne3/op->src[0]->ne[3];
|
||||
|
||||
const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
|
||||
|
||||
float poffs = 0.5f;
|
||||
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
poffs = 0.0f;
|
||||
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
|
||||
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
|
||||
}
|
||||
|
||||
ggml_metal_kargs_upscale args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.sf0 =*/ sf0,
|
||||
/*.sf1 =*/ sf1,
|
||||
/*.sf2 =*/ sf2,
|
||||
/*.sf3 =*/ sf3
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.sf0 =*/ sf0,
|
||||
/*.sf1 =*/ sf1,
|
||||
/*.sf2 =*/ sf2,
|
||||
/*.sf3 =*/ sf3,
|
||||
/*.poffs =*/ poffs,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
|
||||
|
||||
@@ -3481,6 +3481,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
@@ -3531,6 +3538,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
|
||||
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
|
||||
|
||||
template<typename T0, typename T1, short NR0, typename args_t>
|
||||
void kernel_mul_mv_t_t_impl(
|
||||
args_t args,
|
||||
@@ -4530,7 +4547,9 @@ kernel void kernel_conv_transpose_2d<half>(
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
kernel void kernel_upscale_f32(
|
||||
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
|
||||
|
||||
kernel void kernel_upscale_nearest_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
@@ -4556,6 +4575,156 @@ kernel void kernel_upscale_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bilinear_tri(float x) {
|
||||
return MAX(0.0f, 1.0f - fabs(x));
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bilinear_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
|
||||
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
|
||||
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
|
||||
|
||||
src0 += i03*args.nb03 + i02*args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (FC_upscale_aa) {
|
||||
const float support0 = MAX(1.0f, 1.0f / args.sf0);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
const float support1 = MAX(1.0f, 1.0f / args.sf1);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
|
||||
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
|
||||
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
|
||||
|
||||
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
|
||||
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
|
||||
|
||||
float sum = 0.0f;
|
||||
float wsum = 0.0f;
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; ++sy) {
|
||||
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
|
||||
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
||||
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
||||
const float w = wx * wy;
|
||||
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
sum += (*src_ptr) * w;
|
||||
wsum += w;
|
||||
}
|
||||
}
|
||||
|
||||
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
|
||||
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
|
||||
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
|
||||
|
||||
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
|
||||
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
|
||||
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
|
||||
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
|
||||
|
||||
const float v =
|
||||
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
|
||||
(*src10) * fd0 * (1.0f - fd1) +
|
||||
(*src01) * (1.0f - fd0) * fd1 +
|
||||
(*src11) * fd0 * fd1;
|
||||
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bicubic_weight1(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a + 2) * x - (a + 3)) * x * x + 1;
|
||||
}
|
||||
|
||||
static inline float bicubic_weight2(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bicubic_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = (int64_t)floor(f01);
|
||||
const float fd1 = f01 - (float)i01;
|
||||
|
||||
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
|
||||
const float w_y1 = bicubic_weight1(fd1);
|
||||
const float w_y2 = bicubic_weight1(1.0f - fd1);
|
||||
const float w_y3 = bicubic_weight2(2.0f - fd1);
|
||||
|
||||
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = (int64_t)floor(f00);
|
||||
const float fd0 = f00 - (float)i00;
|
||||
|
||||
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
|
||||
const float w_x1 = bicubic_weight1(fd0);
|
||||
const float w_x2 = bicubic_weight1(1.0f - fd0);
|
||||
const float w_x3 = bicubic_weight2(2.0f - fd0);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int dy = -1; dy <= 2; ++dy) {
|
||||
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
|
||||
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
|
||||
|
||||
for (int dx = -1; dx <= 2; ++dx) {
|
||||
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
||||
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
||||
|
||||
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
sum += (*src_ptr) * wx * wy;
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_pad_f32(
|
||||
constant ggml_metal_kargs_pad & args,
|
||||
device const char * src0,
|
||||
|
||||
@@ -744,6 +744,7 @@ struct vk_device_struct {
|
||||
|
||||
// [src/dst 0=fp32,1=fp16]
|
||||
vk_pipeline pipeline_exp[2];
|
||||
vk_pipeline pipeline_elu[2];
|
||||
vk_pipeline pipeline_gelu[2];
|
||||
vk_pipeline pipeline_gelu_erf[2];
|
||||
vk_pipeline pipeline_gelu_quick[2];
|
||||
@@ -762,6 +763,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_ceil[2];
|
||||
vk_pipeline pipeline_floor[2];
|
||||
vk_pipeline pipeline_trunc[2];
|
||||
vk_pipeline pipeline_sgn[2];
|
||||
|
||||
vk_pipeline pipeline_add1_f16_f16;
|
||||
vk_pipeline pipeline_add1_f16_f32;
|
||||
@@ -4373,6 +4375,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
CREATE_UNARY(elu)
|
||||
CREATE_UNARY(gelu)
|
||||
CREATE_UNARY(gelu_erf)
|
||||
CREATE_UNARY(gelu_quick)
|
||||
@@ -4391,6 +4394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_UNARY(ceil)
|
||||
CREATE_UNARY(floor)
|
||||
CREATE_UNARY(trunc)
|
||||
CREATE_UNARY(sgn)
|
||||
#undef CREATE_UNARY
|
||||
|
||||
#define CREATE_UNARY_RTE(name) \
|
||||
@@ -9241,6 +9245,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
switch (ggml_get_unary_op(dst)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_ELU:
|
||||
return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_SILU:
|
||||
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_GELU:
|
||||
@@ -9277,6 +9283,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
|
||||
case GGML_UNARY_OP_SGN:
|
||||
return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -12852,6 +12860,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
}
|
||||
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
@@ -12870,6 +12879,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
ggml_vk_unary(ctx, compute_ctx, src0, node);
|
||||
break;
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
@@ -13248,6 +13258,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t val32 = (uint32_t)value * 0x01010101;
|
||||
ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
|
||||
}
|
||||
@@ -13257,6 +13271,10 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
@@ -13264,12 +13282,20 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
|
||||
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
if (ggml_nbytes(src) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ggml_backend_buffer_is_vk(src->buffer)) {
|
||||
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
@@ -13459,6 +13485,10 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context cpy_ctx;
|
||||
@@ -13502,6 +13532,10 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
||||
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
||||
@@ -13528,9 +13562,14 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
||||
}
|
||||
|
||||
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
|
||||
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
|
||||
|
||||
// Skip zero-size tensors
|
||||
if (ggml_nbytes(src) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
|
||||
return false;
|
||||
}
|
||||
@@ -14951,6 +14990,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
@@ -14969,6 +15009,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
@@ -16074,6 +16115,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
case GGML_UNARY_OP_EXP:
|
||||
tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
case GGML_UNARY_OP_ELU:
|
||||
tensor_clone = ggml_elu(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
@@ -16132,6 +16176,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
|
||||
break;
|
||||
default:
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
float x = float(data_a[i]);
|
||||
|
||||
if (x < 0.0f) {
|
||||
x = exp(x) - 1;
|
||||
}
|
||||
|
||||
data_d[i] = D_TYPE(x);
|
||||
}
|
||||
@@ -377,6 +377,7 @@ void main() {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
barrier();
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
@@ -387,6 +388,7 @@ void main() {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -404,18 +406,22 @@ void main() {
|
||||
// Full coopMat is within bounds, but stride_d is not aligned
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||
// Partial coopMat is within bounds
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_d[i] = D_TYPE(sign(float(data_a[i])));
|
||||
}
|
||||
@@ -867,8 +867,12 @@ void process_shaders() {
|
||||
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("elu_f16", "elu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
@@ -177,6 +177,8 @@ class Keys:
|
||||
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
|
||||
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
|
||||
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
|
||||
KEY_LENGTH_SWA = "{arch}.attention.key_length_swa"
|
||||
VALUE_LENGTH_SWA = "{arch}.attention.value_length_swa"
|
||||
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
|
||||
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
|
||||
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
|
||||
@@ -188,6 +190,7 @@ class Keys:
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
|
||||
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
|
||||
FREQ_BASE = "{arch}.rope.freq_base"
|
||||
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
|
||||
|
||||
@@ -773,6 +773,12 @@ class GGUFWriter:
|
||||
def add_value_length_mla(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
|
||||
|
||||
def add_key_length_swa(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KEY_LENGTH_SWA.format(arch=self.arch), length)
|
||||
|
||||
def add_value_length_swa(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_LENGTH_SWA.format(arch=self.arch), length)
|
||||
|
||||
def add_indexer_head_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count)
|
||||
|
||||
@@ -946,6 +952,9 @@ class GGUFWriter:
|
||||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_rope_dimension_count_swa(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT_SWA.format(arch=self.arch), count)
|
||||
|
||||
def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
|
||||
self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "You can use the following tools: <|tool_list_start|>[" -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool is not string -%}
|
||||
{%- set tool = tool | tojson -%}
|
||||
@@ -17,7 +17,6 @@
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
|
||||
{{- '**IMPORTANT**: The syntax for calling the tools is: <|tool_call_start|>JSON tool call goes here<|tool_call_end|>. Please only call tools in the specified manner.' -}}
|
||||
{%- endif -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
|
||||
@@ -30,18 +29,9 @@
|
||||
{%- endif -%}
|
||||
{%- if message["role"] == "tool" -%}
|
||||
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
|
||||
{%- elif message["role"] == "assistant" -%}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<|tool_call_start|>\n{"name": "' + tool_call.name + '", "arguments": ' + (tool_call.arguments if tool_call.arguments is string else tool_call.arguments | tojson) + '}\n<|tool_call_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- endif -%}
|
||||
{{- content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -1,37 +0,0 @@
|
||||
{{- bos_token -}}
|
||||
{%- set system_prompt = "" -%}
|
||||
{%- set ns = namespace(system_prompt="") -%}
|
||||
{%- if messages[0]["role"] == "system" -%}
|
||||
{%- set ns.system_prompt = messages[0]["content"] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool is not string -%}
|
||||
{%- set tool = tool | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + tool -%}
|
||||
{%- if not loop.last -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message["role"] + "\n" -}}
|
||||
{%- set content = message["content"] -%}
|
||||
{%- if content is not string -%}
|
||||
{%- set content = content | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- if message["role"] == "tool" -%}
|
||||
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
|
||||
{%- endif -%}
|
||||
{{- content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
@@ -230,11 +230,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_SWA, "%s.attention.key_length_swa" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_SWA, "%s.attention.value_length_swa" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
|
||||
|
||||
@@ -234,11 +234,14 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_SWA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_SWA,
|
||||
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
|
||||
LLM_KV_ATTENTION_INDEXER_TOP_K,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_FREQ_BASE_SWA,
|
||||
|
||||
+12
-8
@@ -2876,19 +2876,23 @@ llama_context * llama_init_from_model(
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
||||
if (model->hparams.n_embd_head_k % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
||||
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
|
||||
return nullptr;
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
||||
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
||||
if (model->hparams.n_embd_head_v % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
||||
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
|
||||
return nullptr;
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
|
||||
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+4
-8
@@ -849,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
ubatch (params.ubatch),
|
||||
n_embd (hparams.n_embd),
|
||||
n_layer (hparams.n_layer),
|
||||
n_rot (hparams.n_rot),
|
||||
n_rot (hparams.n_rot()),
|
||||
n_ctx (cparams.n_ctx),
|
||||
n_head (hparams.n_head()),
|
||||
n_head_kv (hparams.n_head_kv()),
|
||||
n_embd_head_k (hparams.n_embd_head_k),
|
||||
n_embd_head_k (hparams.n_embd_head_k()),
|
||||
n_embd_k_gqa (hparams.n_embd_k_gqa()),
|
||||
n_embd_head_v (hparams.n_embd_head_v),
|
||||
n_embd_head_v (hparams.n_embd_head_v()),
|
||||
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
||||
n_expert (hparams.n_expert),
|
||||
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
|
||||
@@ -1151,7 +1151,6 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
// TODO remove redundant scale_w argument
|
||||
ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * gate_inp,
|
||||
@@ -1163,7 +1162,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
@@ -1180,7 +1178,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
n_expert_used,
|
||||
type_op,
|
||||
norm_w,
|
||||
scale_w,
|
||||
w_scale,
|
||||
gating_op,
|
||||
il,
|
||||
@@ -1204,7 +1201,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
@@ -1332,7 +1328,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
|
||||
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
if (scale_w) {
|
||||
if (w_scale != 0.0f && w_scale != 1.0f) {
|
||||
weights = ggml_scale(ctx0, weights, w_scale);
|
||||
cb(weights, "ffn_moe_weights_scaled", il);
|
||||
}
|
||||
|
||||
@@ -810,7 +810,6 @@ struct llm_graph_context {
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
@@ -832,7 +831,6 @@ struct llm_graph_context {
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
|
||||
+28
-4
@@ -62,6 +62,14 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
|
||||
return n_head/n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_rot(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return is_swa(il) ? n_rot_swa : n_rot_full;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_inp() const {
|
||||
uint32_t n_embd_inp = n_embd;
|
||||
|
||||
@@ -76,16 +84,32 @@ uint32_t llama_hparams::n_embd_out() const {
|
||||
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_k(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_v(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
return n_embd_head_k * n_head_kv;
|
||||
return n_embd_head_k(il) * n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
return n_embd_head_v * n_head_kv;
|
||||
return n_embd_head_v(il) * n_head_kv;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_k_gqa_variable() const {
|
||||
@@ -197,11 +221,11 @@ bool llama_hparams::is_mla() const {
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_k_mla() const {
|
||||
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
|
||||
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k();
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_v_mla() const {
|
||||
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
|
||||
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v();
|
||||
}
|
||||
|
||||
bool llama_hparams::has_kv(uint32_t il) const {
|
||||
|
||||
+16
-3
@@ -44,13 +44,20 @@ struct llama_hparams {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
|
||||
uint32_t n_rot;
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_expert = 0;
|
||||
uint32_t n_expert_used = 0;
|
||||
uint32_t n_rel_attn_bkts = 0;
|
||||
|
||||
// different head size for full_attention and SWA layers
|
||||
uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_embd_head_k_swa;
|
||||
uint32_t n_embd_head_v_swa;
|
||||
|
||||
// different RoPE dimensions for full_attention and SWA layers
|
||||
uint32_t n_rot_full;
|
||||
uint32_t n_rot_swa;
|
||||
|
||||
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
|
||||
uint32_t n_embd_head_k_mla_impl = 0;
|
||||
uint32_t n_embd_head_v_mla_impl = 0;
|
||||
@@ -247,12 +254,18 @@ struct llama_hparams {
|
||||
|
||||
uint32_t n_gqa(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_rot(uint32_t il = 0) const;
|
||||
|
||||
// dimension of main + auxiliary input embeddings
|
||||
uint32_t n_embd_inp() const;
|
||||
|
||||
// dimension of output embeddings
|
||||
uint32_t n_embd_out() const;
|
||||
|
||||
// dimension of key/value embeddings for each head (per layer)
|
||||
uint32_t n_embd_head_k(uint32_t il = 0) const;
|
||||
uint32_t n_embd_head_v(uint32_t il = 0) const;
|
||||
|
||||
// dimension of key embeddings across all k-v heads
|
||||
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
|
||||
|
||||
|
||||
+14
-16
@@ -1033,8 +1033,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
|
||||
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
||||
|
||||
return ggml_view_4d(ctx, k,
|
||||
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(k->type, hparams.n_embd_head_k),
|
||||
hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(k->type, hparams.n_embd_head_k(il)),
|
||||
ggml_row_size(k->type, n_embd_k_gqa),
|
||||
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
|
||||
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
|
||||
@@ -1056,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
|
||||
if (!v_trans) {
|
||||
// note: v->nb[1] <= v->nb[2]
|
||||
return ggml_view_4d(ctx, v,
|
||||
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
||||
hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1]
|
||||
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
|
||||
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
|
||||
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
|
||||
@@ -1065,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
|
||||
|
||||
// note: v->nb[1] > v->nb[2]
|
||||
return ggml_view_4d(ctx, v,
|
||||
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
|
||||
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
|
||||
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns,
|
||||
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1]
|
||||
ggml_row_size(v->type, kv_size), // v->nb[2]
|
||||
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
|
||||
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
|
||||
@@ -1544,7 +1544,8 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const {
|
||||
float freq_scale,
|
||||
uint32_t il) const {
|
||||
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
||||
|
||||
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
||||
@@ -1552,7 +1553,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
||||
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
const auto & n_rot = hparams.n_rot(il);
|
||||
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
||||
// @ngxson : this is a workaround
|
||||
// for M-RoPE, we want to rotate the whole vector when doing KV shift
|
||||
@@ -1606,13 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
auto * ctx = res->get_ctx();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
|
||||
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
||||
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
@@ -1626,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
|
||||
const auto n_rot = hparams.n_rot(il);
|
||||
const auto n_embd_head_k = hparams.n_embd_head_k(il);
|
||||
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
||||
|
||||
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
|
||||
@@ -1638,7 +1636,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
@@ -264,7 +264,8 @@ private:
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
float freq_scale,
|
||||
uint32_t il) const;
|
||||
|
||||
ggml_cgraph * build_graph_shift(
|
||||
llm_graph_result * res,
|
||||
|
||||
@@ -918,7 +918,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int n_embd_head = hparams.n_embd_head_v;
|
||||
const int n_embd_head = hparams.n_embd_head_v();
|
||||
const int n_head = hparams.n_head();
|
||||
ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
|
||||
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
|
||||
|
||||
@@ -186,8 +186,10 @@ void llama_model_saver::add_kv_from_model() {
|
||||
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
|
||||
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
|
||||
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
@@ -199,7 +201,8 @@ void llama_model_saver::add_kv_from_model() {
|
||||
|
||||
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
|
||||
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full);
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa);
|
||||
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
|
||||
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
|
||||
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
|
||||
|
||||
+53
-39
@@ -459,26 +459,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
|
||||
// gpt-j n_rot = rotary_dim
|
||||
|
||||
hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
|
||||
hparams.n_embd_head_k_full = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false);
|
||||
|
||||
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
|
||||
hparams.n_embd_head_v_full = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);
|
||||
|
||||
// sanity check for n_rot (optional)
|
||||
hparams.n_rot = hparams.n_embd_head_k;
|
||||
hparams.n_rot_full = hparams.n_embd_head_k_full;
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full, false);
|
||||
|
||||
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
|
||||
if (hparams.n_rot != hparams.n_embd_head_k) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
||||
if (hparams.n_rot_full != hparams.n_embd_head_k_full) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot_full, hparams.n_embd_head_k_full));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
hparams.n_rot = 0;
|
||||
hparams.n_embd_head_k = 0;
|
||||
hparams.n_embd_head_v = 0;
|
||||
hparams.n_rot_full = 0;
|
||||
hparams.n_embd_head_k_full = 0;
|
||||
hparams.n_embd_head_v_full = 0;
|
||||
}
|
||||
|
||||
// head size and n_rot for SWA layers
|
||||
{
|
||||
hparams.n_embd_head_k_swa = hparams.n_embd_head_k_full;
|
||||
hparams.n_embd_head_v_swa = hparams.n_embd_head_v_full;
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa, false);
|
||||
|
||||
hparams.n_rot_swa = hparams.n_rot_full;
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false);
|
||||
}
|
||||
|
||||
// for differentiating model types
|
||||
@@ -1114,10 +1125,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
// Load attention parameters
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
|
||||
} break;
|
||||
case LLM_ARCH_PLAMO3:
|
||||
{
|
||||
@@ -1212,7 +1219,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
|
||||
hparams.f_attention_scale = type == LLM_TYPE_27B
|
||||
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
@@ -1245,7 +1252,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
|
||||
hparams.f_attention_scale = type == LLM_TYPE_27B
|
||||
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3N:
|
||||
{
|
||||
@@ -1294,7 +1301,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
case 24: type = LLM_TYPE_0_3B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
||||
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
@@ -1570,6 +1577,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
|
||||
switch (hparams.n_ff_exp) {
|
||||
case 1408: type = LLM_TYPE_16B; break;
|
||||
@@ -2076,6 +2084,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
@@ -2485,7 +2494,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl);
|
||||
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
|
||||
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda);
|
||||
|
||||
@@ -2516,6 +2524,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
|
||||
// full_attention layer only use half of the RoPE dimensions
|
||||
hparams.n_rot_full = hparams.n_rot_full / 2;
|
||||
|
||||
// MoE + SWA parameters
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
||||
@@ -2659,13 +2670,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k();
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v();
|
||||
const int64_t n_ff = hparams.n_ff();
|
||||
const int64_t n_embd_gqa = n_embd_v_gqa;
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_token_types = vocab.n_token_types();
|
||||
const int64_t n_rot = hparams.n_rot;
|
||||
const int64_t n_rot = hparams.n_rot();
|
||||
const int64_t n_expert = hparams.n_expert;
|
||||
const int64_t n_expert_used = hparams.n_expert_used;
|
||||
const int64_t n_ctx_train = hparams.n_ctx_train;
|
||||
@@ -2965,8 +2976,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_MINICPM3:
|
||||
{
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
|
||||
|
||||
const int64_t q_lora_rank = hparams.n_lora_q;
|
||||
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
||||
@@ -3838,8 +3849,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
|
||||
|
||||
// attention parameters
|
||||
const uint32_t qk_dim = hparams.n_embd_head_k;
|
||||
const uint32_t v_dim = hparams.n_embd_head_v;
|
||||
const uint32_t qk_dim = hparams.n_embd_head_k();
|
||||
const uint32_t v_dim = hparams.n_embd_head_v();
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
@@ -3899,8 +3910,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_PLAMO3:
|
||||
{
|
||||
const int64_t head_dim_q = hparams.n_embd_head_k;
|
||||
const int64_t head_dim_v = hparams.n_embd_head_v;
|
||||
const int64_t head_dim_q = hparams.n_embd_head_k();
|
||||
const int64_t head_dim_v = hparams.n_embd_head_v();
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
@@ -4647,7 +4658,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
const uint32_t head_dim = hparams.n_embd_head_k;
|
||||
const uint32_t head_dim = hparams.n_embd_head_k();
|
||||
const int64_t n_qo_dim = n_head * head_dim;
|
||||
const int64_t n_kv_dim = n_head_kv * head_dim;
|
||||
|
||||
@@ -4876,7 +4887,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
|
||||
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
|
||||
GGML_ASSERT(n_embd_head_qk_nope >= 1);
|
||||
|
||||
@@ -4955,8 +4966,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_PLM:
|
||||
{
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
|
||||
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -5394,7 +5405,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
|
||||
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
|
||||
|
||||
const int64_t q_lora_rank = hparams.n_lora_q;
|
||||
@@ -5678,7 +5689,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_expert = hparams.n_expert;
|
||||
const int64_t n_expert_used = hparams.n_expert_used;
|
||||
const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp;
|
||||
const int64_t head_dim = hparams.n_embd_head_k;
|
||||
const int64_t head_dim = hparams.n_embd_head_k();
|
||||
const int64_t n_qo_dim = n_head * head_dim;
|
||||
const int64_t n_kv_dim = n_head_kv * head_dim;
|
||||
|
||||
@@ -6966,7 +6977,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
// Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA)
|
||||
// Note: hparams.n_rot may be 72 (from conversion) but actual is 64
|
||||
const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim
|
||||
const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim
|
||||
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0);
|
||||
// Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled)
|
||||
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i),
|
||||
@@ -7337,7 +7348,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
// ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer.
|
||||
uint32_t n_rot_max = 0;
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
n_rot_max = std::max(n_rot_max, hparams.n_rot);
|
||||
n_rot_max = std::max(n_rot_max, hparams.n_rot(i));
|
||||
}
|
||||
if (n_rot_max == 0) {
|
||||
n_rot_max = n_rot;
|
||||
@@ -7672,11 +7683,11 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
|
||||
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
|
||||
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
||||
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
|
||||
@@ -7700,6 +7711,9 @@ void llama_model::print_info() const {
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa);
|
||||
LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa);
|
||||
LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa);
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
||||
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
||||
|
||||
+467
-286
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -127,7 +127,6 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU,
|
||||
hparams.expert_weights_norm, // norm_w (route_norm=True)
|
||||
hparams.expert_weights_scale, // scale_w
|
||||
hparams.expert_weights_scale, // w_scale (route_scale=2.826)
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -104,7 +103,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -97,7 +96,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
false, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -90,7 +88,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
hparams.expert_weights_scale, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
+13
-7
@@ -1,12 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -129,9 +127,17 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
|
||||
// feed-forward network
|
||||
if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) {
|
||||
// MoE branch
|
||||
cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, nullptr,
|
||||
model.layers[il].ffn_down_exps, nullptr, hparams.n_expert, hparams.n_expert_used,
|
||||
LLM_FFN_GELU, false, false, 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
|
||||
cur = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
nullptr,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
hparams.n_expert, hparams.n_expert_used,
|
||||
LLM_FFN_GELU, false,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
} else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE ||
|
||||
model.arch == LLM_ARCH_JINA_BERT_V3) {
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
|
||||
llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
#include <float.h>
|
||||
|
||||
llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * cur;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
const float f_logit_scale = hparams.f_logit_scale;
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
const float f_logit_scale = hparams.f_logit_scale;
|
||||
|
||||
|
||||
+4
-5
@@ -1,12 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -89,7 +88,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
+3
-3
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -100,7 +98,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
false, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -8,7 +8,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
@@ -216,7 +216,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
hparams.expert_weights_scale, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il,
|
||||
nullptr,
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -91,7 +89,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
hparams.expert_weights_scale, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
//copied from qwen2
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -103,7 +101,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -100,7 +99,7 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
hparams.expert_weights_scale, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
|
||||
llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
template <bool iswa>
|
||||
llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
|
||||
llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
template <bool iswa>
|
||||
llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
n_embd_head(model.hparams.n_embd_head_k),
|
||||
n_embd_head(model.hparams.n_embd_head_k()),
|
||||
n_embd_altup(model.hparams.n_embd_altup),
|
||||
n_altup(model.hparams.n_altup),
|
||||
i_altup_act(model.hparams.i_altup_act) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
@@ -128,7 +128,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||
hparams.expert_weights_scale, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(routed_out, "ffn_moe_out", il);
|
||||
|
||||
+2
-2
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
+2
-2
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * pos;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -160,7 +159,7 @@ ggml_tensor * llm_build_granite_hybrid::build_layer_ffn(ggml_tensor * cur,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
llm_build_granite::llm_build_granite(
|
||||
const llama_model & model,
|
||||
const llm_graph_params & params)
|
||||
: llm_graph_context(params) {
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -175,7 +174,7 @@ ggml_tensor * llm_build_granite::build_layer_ffn(
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
+4
-4
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -99,7 +99,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_GELU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_chunk_expert = n_expert / hparams.n_group_experts;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -90,7 +88,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il,
|
||||
probs);
|
||||
@@ -106,7 +104,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap
|
||||
nullptr,
|
||||
n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il,
|
||||
probs);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -119,8 +119,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU,
|
||||
true, // norm_topk_prob
|
||||
false,
|
||||
0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur_moe, "ffn_moe_out", il);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
+2
-2
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
// JAIS-2 model graph builder
|
||||
// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings
|
||||
llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -76,7 +76,7 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "models.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
@@ -103,7 +102,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
|
||||
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
||||
// qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
|
||||
// Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot; // config.qk_rope_head_dim
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot(); // config.qk_rope_head_dim
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128
|
||||
// Attention scale for MLA
|
||||
const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
|
||||
@@ -341,7 +340,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
|
||||
hparams.n_expert,
|
||||
hparams.n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
hparams.expert_weights_scale, hparams.expert_weights_scale,
|
||||
hparams.expert_weights_scale,
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
+11
-5
@@ -23,17 +23,23 @@ llm_build_lfm2<iswa>::llm_build_lfm2(const llama_model & model, const llm_graph_
|
||||
};
|
||||
auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * {
|
||||
return build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0,
|
||||
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
hparams.expert_weights_scale,
|
||||
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
|
||||
il);
|
||||
};
|
||||
auto build_attn_block = [&model, this](ggml_tensor * cur,
|
||||
ggml_tensor * inp_pos,
|
||||
inp_attn_type * inp_attn,
|
||||
int il) -> ggml_tensor * {
|
||||
GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
|
||||
const auto n_embd_head = hparams.n_embd_head_v;
|
||||
const auto n_embd_head = hparams.n_embd_head_v();
|
||||
const auto n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
auto * q = build_lora_mm(model.layers[il].wq, cur);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -90,7 +90,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
// LLaDA is similar to LLaMA but uses non-causal attention for diffusion
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -134,7 +134,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
|
||||
il);
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
template <bool embed>
|
||||
llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
@@ -130,7 +130,7 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user