mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 17:47:40 +02:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aa2d278a11 | |||
| 6c770d16ca | |||
| 8d880ac012 | |||
| 0f1e9d14cc | |||
| 1274fbee9e | |||
| a7b3dee7a5 | |||
| ec947d2b16 | |||
| 0cd4f4720b | |||
| af237f3026 | |||
| 1a5631beaa | |||
| 1dab5f5a44 | |||
| c96f608d98 | |||
| 0842b9b465 | |||
| 59db9a357d |
+2
-2
@@ -2427,11 +2427,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
);
|
||||
}
|
||||
if (split_arg.size() == 1) {
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoull(split_arg[0]) * 1024*1024);
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < split_arg.size(); i++) {
|
||||
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
|
||||
params.fit_params_target[i] = std::stoull(split_arg[i]) * 1024*1024;
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_FIT_TARGET"));
|
||||
|
||||
@@ -90,7 +90,7 @@ common_peg_arena autoparser::build_parser(const templates_params & inputs) const
|
||||
// pre-register a json-string rule that accepts both quote styles. This must happen
|
||||
// before any call to p.json() so that all JSON parsing inherits the flexible rule.
|
||||
if (tools.format.uses_python_dicts) {
|
||||
p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); });
|
||||
p.rule("json-string", p.quoted_string());
|
||||
}
|
||||
|
||||
parser_build_context ctx(p, inputs);
|
||||
|
||||
@@ -507,8 +507,8 @@ common_peg_parser common_chat_peg_builder::python_style_tool_calls(
|
||||
|
||||
common_peg_parser arg_value_parser = eps();
|
||||
auto string_value_parser = choice({
|
||||
literal("\"") + tool_arg_string_value(json_string_content()) + literal("\""),
|
||||
literal("'") + tool_arg_string_value(json_string_content()) + literal("'")
|
||||
literal("\"") + tool_arg_string_value(string_content('"')) + literal("\""),
|
||||
literal("'") + tool_arg_string_value(string_content('\'')) + literal("'")
|
||||
});
|
||||
|
||||
if (is_string_type) {
|
||||
@@ -577,7 +577,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
||||
if (!call_id_key.empty()) {
|
||||
auto id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\"")
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\"")
|
||||
);
|
||||
inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space())));
|
||||
}
|
||||
@@ -586,7 +586,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
|
||||
auto gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
@@ -675,7 +675,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
||||
if (id_spec.first.empty()) {
|
||||
auto id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\"")
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\"")
|
||||
);
|
||||
tool_parser_body = tool_parser_body + optional(id_parser + space() + literal(",") + space());
|
||||
}
|
||||
@@ -687,7 +687,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
|
||||
auto gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
@@ -736,7 +736,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
||||
id_parser = atomic(
|
||||
literal("\"" + call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
@@ -747,7 +747,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
|
||||
gen_id_parser = atomic(
|
||||
literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() +
|
||||
choice({
|
||||
literal("\"") + tool_id(json_string_content()) + literal("\""),
|
||||
literal("\"") + tool_id(string_content('"')) + literal("\""),
|
||||
tool_id(json_number())
|
||||
})
|
||||
);
|
||||
|
||||
+2
-2
@@ -1620,8 +1620,8 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) :
|
||||
src_parser;
|
||||
|
||||
if (src_parser.empty()) {
|
||||
LOG_WRN("No parser definition detected, assuming pure content parser.");
|
||||
if (src_parser.empty()) {
|
||||
LOG_DBG("No parser definition detected, assuming pure content parser.");
|
||||
}
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
|
||||
@@ -790,7 +790,7 @@ public:
|
||||
} else if (target.is_array()) {
|
||||
size_t sel_index;
|
||||
try {
|
||||
sel_index = std::stoul(sel);
|
||||
sel_index = std::stoull(sel);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
sel_index = target.size();
|
||||
}
|
||||
|
||||
+119
-129
@@ -658,7 +658,7 @@ struct parser_executor {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
|
||||
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos) {
|
||||
static common_peg_parse_result handle_escape_sequence(common_peg_parse_context & ctx, size_t start, size_t & pos, const char delimiter) {
|
||||
++pos; // consume '\'
|
||||
if (pos >= ctx.input.size()) {
|
||||
if (!ctx.is_lenient()) {
|
||||
@@ -667,23 +667,14 @@ struct parser_executor {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start, pos);
|
||||
}
|
||||
|
||||
switch (ctx.input[pos]) {
|
||||
case '"':
|
||||
case '\'':
|
||||
case '\\':
|
||||
case '/':
|
||||
case 'b':
|
||||
case 'f':
|
||||
case 'n':
|
||||
case 'r':
|
||||
case 't':
|
||||
++pos;
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
|
||||
case 'u':
|
||||
return handle_unicode_escape(ctx, start, pos);
|
||||
default:
|
||||
// Invalid escape sequence
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
char c = ctx.input[pos];
|
||||
if (c == delimiter || c == '\\' || c == '/' || c == 'b' || c == 'f' || c == 'n' || c == 'r' || c == 't') {
|
||||
++pos;
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
|
||||
} else if (c == 'u') {
|
||||
return handle_unicode_escape(ctx, start, pos);
|
||||
} else {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -704,62 +695,20 @@ struct parser_executor {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start, pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_json_string_parser & /* p */) {
|
||||
common_peg_parse_result operator()(const common_peg_string_parser & p) {
|
||||
auto pos = start_pos;
|
||||
|
||||
// Parse string content (without quotes)
|
||||
while (pos < ctx.input.size()) {
|
||||
char c = ctx.input[pos];
|
||||
|
||||
if (c == '"') {
|
||||
// Found closing quote - success (don't consume it)
|
||||
if (c == p.delimiter) {
|
||||
// Found closing delimiter - success (don't consume it)
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
|
||||
if (c == '\\') {
|
||||
auto result = handle_escape_sequence(ctx, start_pos, pos);
|
||||
if (!result.success()) {
|
||||
return result;
|
||||
}
|
||||
} else {
|
||||
auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos);
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INCOMPLETE) {
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
|
||||
if (utf8_result.status == utf8_parse_result::INVALID) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos);
|
||||
}
|
||||
|
||||
pos += utf8_result.bytes_consumed;
|
||||
}
|
||||
}
|
||||
|
||||
// Reached end without finding closing quote
|
||||
if (!ctx.is_lenient()) {
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos);
|
||||
}
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) {
|
||||
auto pos = start_pos;
|
||||
|
||||
// Parse string content (without quotes)
|
||||
while (pos < ctx.input.size()) {
|
||||
char c = ctx.input[pos];
|
||||
|
||||
if (c == '\'') {
|
||||
// Found closing quote - success (don't consume it)
|
||||
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos);
|
||||
}
|
||||
|
||||
if (c == '\\') {
|
||||
auto result = handle_escape_sequence(ctx, start_pos, pos);
|
||||
auto result = handle_escape_sequence(ctx, start_pos, pos, p.delimiter);
|
||||
if (!result.success()) {
|
||||
return result;
|
||||
}
|
||||
@@ -988,8 +937,7 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_ref_parser> ||
|
||||
std::is_same_v<T, common_peg_until_parser> ||
|
||||
std::is_same_v<T, common_peg_literal_parser> ||
|
||||
std::is_same_v<T, common_peg_json_string_parser> ||
|
||||
std::is_same_v<T, common_peg_python_dict_string_parser> ||
|
||||
std::is_same_v<T, common_peg_string_parser> ||
|
||||
std::is_same_v<T, common_peg_chars_parser> ||
|
||||
std::is_same_v<T, common_peg_any_parser> ||
|
||||
std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1065,10 +1013,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", unbounded)";
|
||||
}
|
||||
return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return "JsonString()";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return "PythonDictString()";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
|
||||
return "String(" + std::string(1, p.delimiter) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
return "Until(" + string_join(p.delimiters, " | ") + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
@@ -1281,47 +1227,25 @@ common_peg_arena common_peg_parser_builder::build() {
|
||||
|
||||
// String primitives
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_string_content() {
|
||||
return wrap(arena_.add_parser(common_peg_json_string_parser{}));
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string_content() {
|
||||
return wrap(arena_.add_parser(common_peg_python_dict_string_parser{}));
|
||||
common_peg_parser common_peg_parser_builder::string_content(char delimiter) {
|
||||
return wrap(arena_.add_parser(common_peg_string_parser{delimiter}));
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::double_quoted_string() {
|
||||
return rule("dq-string",
|
||||
[this]() { return sequence({ literal("\""), json_string_content(), literal("\""), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string() {
|
||||
return rule("sq-string",
|
||||
[this]() { return sequence({ literal("'"), single_quoted_string_content(), literal("'"), space() }); });
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::flexible_string() {
|
||||
return rule("flexible-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
|
||||
}
|
||||
|
||||
// Generic helpers for object/array structure
|
||||
|
||||
common_peg_parser common_peg_parser_builder::generic_object(const std::string & name,
|
||||
const common_peg_parser & string_parser,
|
||||
const common_peg_parser & value_parser) {
|
||||
return rule(name, [this, string_parser, value_parser]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser });
|
||||
auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) });
|
||||
return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) });
|
||||
return rule("double-quoted-string", [this]() {
|
||||
return sequence({literal("\""), string_content('"'), literal("\""), space()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::generic_array(const std::string & name,
|
||||
const common_peg_parser & value_parser) {
|
||||
return rule(name, [this, value_parser]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) });
|
||||
return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) });
|
||||
common_peg_parser common_peg_parser_builder::single_quoted_string() {
|
||||
return rule("single-quoted-string", [this]() {
|
||||
return sequence({literal("'"), string_content('\''), literal("'"), space()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::quoted_string() {
|
||||
return rule("quoted-string", [this]() {
|
||||
return choice({double_quoted_string(), single_quoted_string()});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1344,7 +1268,7 @@ common_peg_parser common_peg_parser_builder::json_number() {
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_string() {
|
||||
return rule("json-string", [this]() {
|
||||
return sequence({literal("\""), json_string_content(), literal("\""), space()});
|
||||
return sequence({literal("\""), string_content('"'), literal("\""), space()});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1361,11 +1285,36 @@ common_peg_parser common_peg_parser_builder::json_null() {
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_object() {
|
||||
return generic_object("json-object", json_string(), json());
|
||||
return rule("json-object", [this]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({json_string(), ws, literal(":"), ws, json()});
|
||||
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
|
||||
return sequence({
|
||||
literal("{"),
|
||||
ws,
|
||||
choice({
|
||||
literal("}"),
|
||||
sequence({members, ws, literal("}")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json_array() {
|
||||
return generic_array("json-array", json());
|
||||
return rule("json-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
choice({
|
||||
literal("]"),
|
||||
sequence({elements, ws, literal("]")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::json() {
|
||||
@@ -1382,7 +1331,9 @@ common_peg_parser common_peg_parser_builder::json() {
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_string() {
|
||||
return rule("python-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); });
|
||||
return rule("python-string", [this]() {
|
||||
return choice({double_quoted_string(), single_quoted_string()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_number() {
|
||||
@@ -1390,24 +1341,63 @@ common_peg_parser common_peg_parser_builder::python_number() {
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_bool() {
|
||||
return rule("python-bool", [this]() { return sequence({ choice({ literal("True"), literal("False") }), space() }); });
|
||||
return rule("python-bool", [this]() {
|
||||
return sequence({
|
||||
choice({literal("True"), literal("False")}),
|
||||
space()
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_null() {
|
||||
return rule("python-none", [this]() { return sequence({ literal("None"), space() }); });
|
||||
return rule("python-none", [this]() {
|
||||
return sequence({literal("None"), space()});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_dict() {
|
||||
return generic_object("python-dict", python_string(), python_value());
|
||||
return rule("python-dict", [this]() {
|
||||
auto ws = space();
|
||||
auto member = sequence({python_string(), ws, literal(":"), ws, python_value()});
|
||||
auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))});
|
||||
return sequence({
|
||||
literal("{"),
|
||||
ws,
|
||||
choice({
|
||||
literal("}"),
|
||||
sequence({members, ws, literal("}")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_array() {
|
||||
return generic_array("python-array", python_value());
|
||||
return rule("python-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({python_value(), zero_or_more(sequence({literal(","), ws, python_value()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
choice({
|
||||
literal("]"),
|
||||
sequence({elements, ws, literal("]")})
|
||||
}),
|
||||
ws
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::python_value() {
|
||||
return rule("python-value", [this]() {
|
||||
return choice({ python_dict(), python_array(), python_string(), python_number(), python_bool(), python_null() });
|
||||
return choice({
|
||||
python_dict(),
|
||||
python_array(),
|
||||
python_string(),
|
||||
python_number(),
|
||||
python_bool(),
|
||||
python_null()
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1528,8 +1518,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_chars_parser> ||
|
||||
std::is_same_v<T, common_peg_space_parser> ||
|
||||
std::is_same_v<T, common_peg_any_parser> ||
|
||||
std::is_same_v<T, common_peg_json_string_parser> ||
|
||||
std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
std::is_same_v<T, common_peg_string_parser>) {
|
||||
// These parsers do not have any children
|
||||
} else if constexpr (std::is_same_v<T, common_peg_sequence_parser>) {
|
||||
for (auto child : p.children) {
|
||||
@@ -1665,10 +1654,9 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return result + "{" + std::to_string(p.min_count) + "}";
|
||||
}
|
||||
return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
|
||||
const std::string delim(1, p.delimiter);
|
||||
return R"(( [^)" + delim + R"(\\] | "\\" ( [)" + delim + R"(\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
@@ -1798,10 +1786,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
{"min_count", p.min_count},
|
||||
{"max_count", p.max_count}
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_json_string_parser>) {
|
||||
return json{{"type", "json_string"}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_python_dict_string_parser>) {
|
||||
return json{{ "type", "python_dict_string" }};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_string_parser>) {
|
||||
return json{{"type", "string"}, {"delimiter", std::string(1, p.delimiter)}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_until_parser>) {
|
||||
return json{{"type", "until"}, {"delimiters", p.delimiters}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
@@ -1928,11 +1914,15 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
if (type == "json_string") {
|
||||
return common_peg_json_string_parser{};
|
||||
}
|
||||
if (type == "python_dict_string") {
|
||||
return common_peg_python_dict_string_parser{};
|
||||
if (type == "string") {
|
||||
if (!j.contains("delimiter")) {
|
||||
throw std::runtime_error("string parser missing delimiter field.");
|
||||
}
|
||||
std::string delimiter = j["delimiter"];
|
||||
if (delimiter.empty()) {
|
||||
throw std::runtime_error("string parser delimiter is empty.");
|
||||
}
|
||||
return common_peg_string_parser{delimiter[0]};
|
||||
}
|
||||
if (type == "until") {
|
||||
if (!j.contains("delimiters") || !j["delimiters"].is_array()) {
|
||||
|
||||
+7
-14
@@ -231,8 +231,9 @@ struct common_peg_chars_parser {
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_json_string_parser {};
|
||||
struct common_peg_python_dict_string_parser {};
|
||||
struct common_peg_string_parser {
|
||||
char delimiter;
|
||||
};
|
||||
|
||||
struct common_peg_until_parser {
|
||||
std::vector<std::string> delimiters;
|
||||
@@ -280,8 +281,7 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_any_parser,
|
||||
common_peg_space_parser,
|
||||
common_peg_chars_parser,
|
||||
common_peg_json_string_parser,
|
||||
common_peg_python_dict_string_parser,
|
||||
common_peg_string_parser,
|
||||
common_peg_until_parser,
|
||||
common_peg_schema_parser,
|
||||
common_peg_rule_parser,
|
||||
@@ -340,10 +340,6 @@ class common_peg_parser_builder {
|
||||
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
||||
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
||||
|
||||
// Generic helpers for building object/array structures with configurable string/value parsers.
|
||||
common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser);
|
||||
common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser);
|
||||
|
||||
public:
|
||||
common_peg_parser_builder();
|
||||
|
||||
@@ -444,13 +440,10 @@ class common_peg_parser_builder {
|
||||
common_peg_parser single_quoted_string();
|
||||
|
||||
// Matches a string that accepts both double-quoted and single-quoted styles.
|
||||
common_peg_parser flexible_string();
|
||||
common_peg_parser quoted_string();
|
||||
|
||||
// Matches double-quoted string content without the surrounding quotes.
|
||||
common_peg_parser json_string_content();
|
||||
|
||||
// Matches single-quoted string content without the surrounding quotes.
|
||||
common_peg_parser single_quoted_string_content();
|
||||
// Matches string content without the surrounding delimiter.
|
||||
common_peg_parser string_content(char delimiter);
|
||||
|
||||
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
||||
// value -> object | array | string | number | true | false | null
|
||||
|
||||
+7
-1
@@ -599,7 +599,13 @@ If KleidiAI is enabled, the output will contain a line similar to:
|
||||
```
|
||||
load_tensors: CPU_KLEIDIAI model buffer size = 3474.00 MiB
|
||||
```
|
||||
KleidiAI's microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm and SME. llama.cpp selects the most efficient kernel based on runtime CPU feature detection. However, on platforms that support SME, you must manually enable SME microkernels by setting the environment variable `GGML_KLEIDIAI_SME=1`.
|
||||
KleidiAI’s microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm, SVE, and SME. Llama.cpp selects the most efficient kernels at runtime based on detected CPU capabilities.
|
||||
On CPUs that support SME, SME microkernels are enabled automatically using runtime detection.
|
||||
The environment variable GGML_KLEIDIAI_SME can be used to control SME behavior:
|
||||
- Not set: enable SME automatically if supported and detected.
|
||||
- 0: disable SME.
|
||||
- <n> > 0: enable SME and assume <n> available SME units (override auto detection).
|
||||
If SME is not supported by the CPU, SME microkernels are always disabled.
|
||||
|
||||
Depending on your build target, other higher priority backends may be enabled by default. To ensure the CPU backend is used, you must disable the higher priority backends either at compile time, e.g. -DGGML_METAL=OFF, or during run-time using the command line option `--device none`.
|
||||
|
||||
|
||||
+2
-2
@@ -47,7 +47,7 @@ Legend:
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
@@ -76,7 +76,7 @@ Legend:
|
||||
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
|
||||
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
||||
+1689
-6836
File diff suppressed because it is too large
Load Diff
@@ -633,7 +633,7 @@ class SchemaConverter:
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
items = schema.get('items', schema.get('prefixItems'))
|
||||
if isinstance(items, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
|
||||
@@ -202,8 +202,9 @@
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -520,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .required_cpu = */ CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
@@ -631,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .required_cpu = */ CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
@@ -801,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .required_cpu = */ CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+1202
-3
File diff suppressed because it is too large
Load Diff
@@ -28,13 +28,17 @@ template <int K, int N> struct block {
|
||||
// control size
|
||||
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
|
||||
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
|
||||
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding");
|
||||
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
|
||||
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
|
||||
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding");
|
||||
|
||||
using block_q4_0x4 = block<4, 4>;
|
||||
using block_q4_0x8 = block<4, 8>;
|
||||
using block_q4_0x16 = block<4, 16>;
|
||||
using block_q8_0x4 = block<8, 4>;
|
||||
using block_q8_0x8 = block<8, 8>;
|
||||
using block_q8_0x16 = block<8, 16>;
|
||||
|
||||
struct block_q4_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
@@ -44,7 +48,14 @@ struct block_q4_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
struct block_q4_Kx16 {
|
||||
ggml_half d[16]; // super-block scale for quantized scales
|
||||
ggml_half dmin[16]; // super-block scale for quantized mins
|
||||
uint8_t scales[192]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[2048]; // 4--bit quants
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding");
|
||||
struct block_q2_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
@@ -53,6 +64,13 @@ struct block_q2_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||
struct block_q2_Kx16 {
|
||||
ggml_half d[16]; // Super-block scale for quantized scales
|
||||
ggml_half dmin[16]; // Super-block scale for quantized mins
|
||||
uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks)
|
||||
uint8_t qs[1024]; // Data (16 cols * 64 bytes per block)
|
||||
};
|
||||
static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding");
|
||||
|
||||
struct block_q5_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
@@ -97,6 +115,12 @@ struct block_iq4_nlx8 {
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
|
||||
|
||||
struct block_iq4_nlx16 {
|
||||
ggml_half d[16]; // deltas for 16 iq4_nl blocks
|
||||
uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding");
|
||||
struct block_mxfp4x4 {
|
||||
uint8_t e[4];
|
||||
uint8_t qs[QK_MXFP4 * 2];
|
||||
@@ -109,7 +133,6 @@ struct block_mxfp4x8 {
|
||||
};
|
||||
static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
|
||||
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -132,6 +155,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -146,10 +171,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif
|
||||
|
||||
// Native implementations
|
||||
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
@@ -170,6 +207,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -184,10 +223,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif
|
||||
|
||||
#if defined(__cplusplus)
|
||||
} // extern "C"
|
||||
|
||||
@@ -75,6 +75,10 @@ struct ggml_metal {
|
||||
// abort ggml_metal_graph_compute if callback returns true
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_data;
|
||||
|
||||
// error state - set when a command buffer fails during synchronize
|
||||
// once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated
|
||||
bool has_error;
|
||||
};
|
||||
|
||||
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||
@@ -158,6 +162,8 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||
res->capture_started = false;
|
||||
res->capture_scope = nil;
|
||||
|
||||
res->has_error = false;
|
||||
|
||||
res->gf = nil;
|
||||
res->encode_async = nil;
|
||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||
@@ -246,7 +252,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
ctx->has_error = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,7 +269,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
// release this and all remaining command buffers before returning
|
||||
for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {
|
||||
[ctx->cmd_bufs_ext[j] release];
|
||||
}
|
||||
[ctx->cmd_bufs_ext removeAllObjects];
|
||||
|
||||
ctx->has_error = true;
|
||||
return;
|
||||
}
|
||||
|
||||
[cmd_buf release];
|
||||
@@ -414,6 +429,11 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
|
||||
}
|
||||
|
||||
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
|
||||
if (ctx->has_error) {
|
||||
GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__);
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
|
||||
// number of nodes encoded by the main thread (empirically determined)
|
||||
const int n_main = MAX(64, 0.1*gf->n_nodes);
|
||||
|
||||
|
||||
@@ -42,11 +42,20 @@
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
||||
|
||||
// Matrix-vector multiplication parameters
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
|
||||
// Must be multiple of 4 to work with vectorized paths, and must divide
|
||||
// mul_mat_vec wg size
|
||||
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
|
||||
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
|
||||
|
||||
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
|
||||
// Requires at least two (and multiple of 2) k-quant blocks per tile
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
|
||||
|
||||
// default size for legacy matrix multiplication
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
@@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
|
||||
bool src_overlap;
|
||||
|
||||
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
|
||||
src_overlap == other.src_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -749,6 +759,36 @@ class ggml_webgpu_shader_lib {
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "mul_mat_vec";
|
||||
|
||||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// Quantized types: use helpers but accumulate in f16
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
variant += "_" + src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// src1 type (vector)
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
@@ -763,39 +803,21 @@ class ggml_webgpu_shader_lib {
|
||||
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
|
||||
}
|
||||
|
||||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// Quantized types: use helpers but accumulate in f16
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// VEC/SCALAR controls
|
||||
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
||||
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
||||
|
||||
if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
|
||||
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
||||
@@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_binary_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.type = context.dst->type,
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.src_overlap = context.src_overlap,
|
||||
};
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-webgpu-shader-lib.hpp"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
@@ -20,12 +19,18 @@
|
||||
#include <condition_variable>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
# include <iomanip>
|
||||
#endif
|
||||
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
|
||||
# include <iostream>
|
||||
#endif
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
|
||||
@@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
||||
#endif // GGML_WEBGPU_CPU_PROFILE
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
|
||||
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
|
||||
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
|
||||
#endif
|
||||
|
||||
/* Constants */
|
||||
|
||||
#define WEBGPU_NUM_PARAM_BUFS 48u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
|
||||
#define WEBGPU_NUM_PARAM_BUFS 96u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
|
||||
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
|
||||
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
|
||||
// parameter buffer pool
|
||||
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
|
||||
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
|
||||
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
|
||||
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
|
||||
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
|
||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||
|
||||
// For operations which process a row in parallel, this seems like a reasonable
|
||||
// default
|
||||
@@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
||||
wgpu::BufferUsage usage,
|
||||
const char * label);
|
||||
|
||||
struct webgpu_pool_bufs {
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
};
|
||||
|
||||
// Holds a pool of parameter buffers for WebGPU operations
|
||||
struct webgpu_buf_pool {
|
||||
std::vector<webgpu_pool_bufs> free;
|
||||
std::vector<wgpu::Buffer> free;
|
||||
|
||||
// The pool must be synchronized because
|
||||
// 1. The memset pool is shared globally by every ggml buffer,
|
||||
@@ -138,7 +137,6 @@ struct webgpu_buf_pool {
|
||||
size_t cur_pool_size;
|
||||
size_t max_pool_size;
|
||||
wgpu::Device device;
|
||||
wgpu::BufferUsage host_buf_usage;
|
||||
wgpu::BufferUsage dev_buf_usage;
|
||||
size_t buf_size;
|
||||
bool should_grow;
|
||||
@@ -147,53 +145,47 @@ struct webgpu_buf_pool {
|
||||
int num_bufs,
|
||||
size_t buf_size,
|
||||
wgpu::BufferUsage dev_buf_usage,
|
||||
wgpu::BufferUsage host_buf_usage,
|
||||
bool should_grow = false,
|
||||
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
|
||||
this->max_pool_size = max_pool_size;
|
||||
this->cur_pool_size = num_bufs;
|
||||
this->device = device;
|
||||
this->host_buf_usage = host_buf_usage;
|
||||
this->dev_buf_usage = dev_buf_usage;
|
||||
this->buf_size = buf_size;
|
||||
this->should_grow = should_grow;
|
||||
this->max_pool_size = max_pool_size;
|
||||
this->cur_pool_size = num_bufs;
|
||||
this->device = device;
|
||||
this->dev_buf_usage = dev_buf_usage;
|
||||
this->buf_size = buf_size;
|
||||
this->should_grow = should_grow;
|
||||
for (int i = 0; i < num_bufs; i++) {
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
free.push_back({ host_buf, dev_buf });
|
||||
free.push_back(dev_buf);
|
||||
}
|
||||
}
|
||||
|
||||
webgpu_pool_bufs alloc_bufs() {
|
||||
wgpu::Buffer alloc_bufs() {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (!free.empty()) {
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
wgpu::Buffer buf = free.back();
|
||||
free.pop_back();
|
||||
return bufs;
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Try growing the pool if no free buffers
|
||||
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
|
||||
cur_pool_size++;
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
|
||||
if (!(host_buf && dev_buf)) {
|
||||
if (!dev_buf) {
|
||||
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
|
||||
}
|
||||
return webgpu_pool_bufs{ host_buf, dev_buf };
|
||||
return dev_buf;
|
||||
}
|
||||
cv.wait(lock, [this] { return !free.empty(); });
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
wgpu::Buffer buf = free.back();
|
||||
free.pop_back();
|
||||
return bufs;
|
||||
return buf;
|
||||
}
|
||||
|
||||
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
|
||||
void free_bufs(std::vector<wgpu::Buffer> bufs) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
free.insert(free.end(), bufs.begin(), bufs.end());
|
||||
cv.notify_all();
|
||||
@@ -201,12 +193,9 @@ struct webgpu_buf_pool {
|
||||
|
||||
void cleanup() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
for (auto & bufs : free) {
|
||||
if (bufs.host_buf) {
|
||||
bufs.host_buf.Destroy();
|
||||
}
|
||||
if (bufs.dev_buf) {
|
||||
bufs.dev_buf.Destroy();
|
||||
for (auto & buf : free) {
|
||||
if (buf) {
|
||||
buf.Destroy();
|
||||
}
|
||||
}
|
||||
free.clear();
|
||||
@@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool {
|
||||
#endif
|
||||
|
||||
struct webgpu_command {
|
||||
uint32_t num_kernels;
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
uint32_t num_kernels;
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<wgpu::Buffer> params_bufs;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
webgpu_gpu_profile_bufs timestamp_query_bufs;
|
||||
std::string pipeline_name;
|
||||
@@ -358,6 +346,13 @@ struct webgpu_global_context_struct {
|
||||
|
||||
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
|
||||
|
||||
struct webgpu_submission {
|
||||
wgpu::FutureWaitInfo submit_done;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
std::vector<wgpu::FutureWaitInfo> profile_futures;
|
||||
#endif
|
||||
};
|
||||
|
||||
// All the base objects needed to run operations on a WebGPU device
|
||||
struct webgpu_context_struct {
|
||||
// Points to global instances owned by ggml_backend_webgpu_reg_context
|
||||
@@ -366,7 +361,8 @@ struct webgpu_context_struct {
|
||||
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
wgpu::Buffer set_rows_dev_error_buf;
|
||||
wgpu::Buffer set_rows_host_error_buf;
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
|
||||
@@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
||||
/** End WebGPU object initializations */
|
||||
|
||||
/** WebGPU Actions */
|
||||
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
|
||||
static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
|
||||
switch (status) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
return true;
|
||||
case wgpu::WaitStatus::TimedOut:
|
||||
if (allow_timeout) {
|
||||
return false;
|
||||
}
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
|
||||
return false;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
return false;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
futures.erase(std::remove_if(futures.begin(), futures.end(),
|
||||
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
|
||||
futures.end());
|
||||
}
|
||||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures,
|
||||
bool block) {
|
||||
if (futures.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
||||
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
|
||||
if (waitStatus == wgpu::WaitStatus::Error) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
if (block) {
|
||||
while (!futures.empty()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
||||
ggml_backend_webgpu_erase_completed_futures(futures);
|
||||
}
|
||||
}
|
||||
if (futures[0].completed) {
|
||||
futures.erase(futures.begin());
|
||||
} else {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
|
||||
ggml_backend_webgpu_erase_completed_futures(futures);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_submission> & subs,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
|
||||
#endif
|
||||
subs.erase(subs.begin());
|
||||
}
|
||||
}
|
||||
|
||||
if (futures.empty()) {
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (block) {
|
||||
while (!futures.empty()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
break;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
break;
|
||||
for (auto & sub : subs) {
|
||||
while (!sub.submit_done.completed) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus);
|
||||
}
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
|
||||
#endif
|
||||
}
|
||||
subs.clear();
|
||||
} else {
|
||||
// Poll once and return
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::TimedOut:
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
break;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
break;
|
||||
// Poll each submit future once and remove completed submissions.
|
||||
for (auto sub = subs.begin(); sub != subs.end();) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
|
||||
if (sub->submit_done.completed && sub->profile_futures.empty()) {
|
||||
#else
|
||||
if (sub->submit_done.completed) {
|
||||
#endif
|
||||
sub = subs.erase(sub);
|
||||
} else {
|
||||
++sub;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
webgpu_global_context ctx,
|
||||
std::vector<webgpu_command> commands,
|
||||
webgpu_buf_pool & param_buf_pool,
|
||||
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
|
||||
static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_command> & commands,
|
||||
webgpu_buf_pool & param_buf_pool) {
|
||||
std::vector<wgpu::CommandBuffer> command_buffers;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
std::vector<wgpu::Buffer> params_bufs;
|
||||
webgpu_submission submission;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
|
||||
#endif
|
||||
@@ -569,14 +601,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
for (const auto & command : commands) {
|
||||
command_buffers.push_back(command.commands);
|
||||
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
|
||||
if (command.set_rows_error_bufs) {
|
||||
set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
|
||||
}
|
||||
}
|
||||
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
|
||||
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
|
||||
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
@@ -586,27 +613,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
// Free the staged buffers
|
||||
param_buf_pool.free_bufs(params_bufs);
|
||||
});
|
||||
futures.push_back({ p_f });
|
||||
|
||||
for (const auto & bufs : set_rows_error_bufs) {
|
||||
wgpu::Future f = bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
[set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
|
||||
} else {
|
||||
const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
||||
}
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
if (set_rows_error_buf_pool) {
|
||||
set_rows_error_buf_pool->free_bufs({ bufs });
|
||||
}
|
||||
}
|
||||
});
|
||||
futures.push_back({ f });
|
||||
}
|
||||
submission.submit_done = { p_f };
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
for (const auto & command : commands) {
|
||||
@@ -623,14 +630,14 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
// WebGPU timestamps are in ns; convert to ms
|
||||
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
|
||||
ctx->shader_gpu_time_ms[label] += elapsed_ms;
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
|
||||
}
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
|
||||
});
|
||||
futures.push_back({ f });
|
||||
submission.profile_futures.push_back({ f });
|
||||
}
|
||||
#endif
|
||||
return futures;
|
||||
return submission;
|
||||
}
|
||||
|
||||
static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
@@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
const std::vector<webgpu_pipeline> & pipelines,
|
||||
const std::vector<std::vector<uint32_t>> & params_list,
|
||||
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
|
||||
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
|
||||
const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
|
||||
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
|
||||
GGML_ASSERT(pipelines.size() == params_list.size());
|
||||
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
|
||||
GGML_ASSERT(pipelines.size() == workgroups_list.size());
|
||||
|
||||
std::vector<webgpu_pool_bufs> params_bufs_list;
|
||||
std::vector<wgpu::BindGroup> bind_groups;
|
||||
std::vector<wgpu::Buffer> params_bufs_list;
|
||||
std::vector<wgpu::BindGroup> bind_groups;
|
||||
|
||||
for (size_t i = 0; i < pipelines.size(); i++) {
|
||||
webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
|
||||
|
||||
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
|
||||
params_bufs.host_buf.GetSize());
|
||||
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
|
||||
for (size_t j = 0; j < params_list[i].size(); j++) {
|
||||
_params[j] = params_list[i][j];
|
||||
}
|
||||
params_bufs.host_buf.Unmap();
|
||||
wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
|
||||
uint32_t params_binding_num = entries.size();
|
||||
entries.push_back({ .binding = params_binding_num,
|
||||
.buffer = params_bufs.dev_buf,
|
||||
.offset = 0,
|
||||
.size = params_bufs.dev_buf.GetSize() });
|
||||
entries.push_back(
|
||||
{ .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
|
||||
|
||||
wgpu::BindGroupDescriptor bind_group_desc;
|
||||
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
|
||||
@@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
}
|
||||
|
||||
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
||||
for (const auto & params_bufs : params_bufs_list) {
|
||||
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
|
||||
}
|
||||
|
||||
// If there are SET_ROWS operations in this submission, copy their error
|
||||
// buffers to the host.
|
||||
if (set_rows_error_bufs) {
|
||||
encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
|
||||
set_rows_error_bufs->host_buf.GetSize());
|
||||
for (size_t i = 0; i < params_bufs_list.size(); i++) {
|
||||
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
@@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
webgpu_command result = {};
|
||||
result.commands = commands;
|
||||
result.params_bufs = params_bufs_list;
|
||||
result.set_rows_error_bufs = set_rows_error_bufs;
|
||||
result.num_kernels = pipelines.size();
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
result.timestamp_query_bufs = ts_bufs;
|
||||
@@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &
|
||||
std::vector<uint32_t> params,
|
||||
std::vector<wgpu::BindGroupEntry> bind_group_entries,
|
||||
uint32_t wg_x,
|
||||
uint32_t wg_y = 1,
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
|
||||
uint32_t wg_y = 1) {
|
||||
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
|
||||
{
|
||||
pipeline
|
||||
},
|
||||
{ params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
|
||||
{ std::move(params) }, { std::move(bind_group_entries) },
|
||||
{ { wg_x, wg_y } });
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
||||
@@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
||||
|
||||
webgpu_command command =
|
||||
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
|
||||
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
|
||||
ggml_backend_webgpu_wait(ctx, futures);
|
||||
std::vector<webgpu_command> commands = { command };
|
||||
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
|
||||
ggml_backend_webgpu_wait(ctx, sub);
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
@@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
||||
std::cout << "\nggml_webgpu: gpu breakdown:\n";
|
||||
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
|
||||
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
|
||||
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
||||
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
|
||||
<< pct << "%)\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -978,14 +968,6 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
|
||||
if (decisions->i64_idx) {
|
||||
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
|
||||
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
||||
error_bufs->host_buf.Unmap();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
@@ -1018,8 +1000,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
};
|
||||
|
||||
if (decisions->i64_idx) {
|
||||
entries.push_back(
|
||||
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
|
||||
entries.push_back({ .binding = 3,
|
||||
.buffer = ctx->set_rows_dev_error_buf,
|
||||
.offset = 0,
|
||||
.size = ctx->set_rows_dev_error_buf.GetSize() });
|
||||
}
|
||||
|
||||
uint32_t threads;
|
||||
@@ -1029,8 +1013,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
|
||||
}
|
||||
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
|
||||
error_bufs);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
|
||||
}
|
||||
|
||||
// Workgroup size is a common constant
|
||||
@@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
use_fast = (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_Q6_K:
|
||||
use_fast = true;
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
|
||||
use_fast = !is_vec;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
} else if (use_fast) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
// Fast-path tiled/subgroup calculations
|
||||
uint32_t wg_m, wg_n;
|
||||
uint32_t wg_m;
|
||||
uint32_t wg_n;
|
||||
if (decisions->use_subgroup_matrix) {
|
||||
uint32_t wg_m_sg_tile =
|
||||
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
|
||||
@@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
} else { // legacy
|
||||
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_size = decisions->wg_size;
|
||||
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
@@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
uint32_t dim = (uint32_t) dst->op_params[0];
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
@@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
dim,
|
||||
(uint32_t)src0->ne[dim]
|
||||
(uint32_t) src0->ne[dim]
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0)
|
||||
},
|
||||
{
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1)
|
||||
},
|
||||
{
|
||||
.binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst)
|
||||
}
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
@@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
@@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
@@ -2172,19 +2169,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
case GGML_OP_SOFT_MAX:
|
||||
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
|
||||
case GGML_OP_UNARY:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_CLAMP:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_FILL:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_LOG:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQR:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQRT:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SIN:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_COS:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_PAD:
|
||||
@@ -2192,7 +2182,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
case GGML_OP_ARGMAX:
|
||||
return ggml_webgpu_argmax(ctx, src0, node);
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_webgpu_argsort(ctx, src0, node);
|
||||
case GGML_OP_TOP_K:
|
||||
// we reuse the same argsort implementation for top_k
|
||||
return ggml_webgpu_argsort(ctx, src0, node);
|
||||
@@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
|
||||
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
|
||||
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<webgpu_submission> subs;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
bool contains_set_rows = false;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
|
||||
contains_set_rows = true;
|
||||
}
|
||||
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
|
||||
commands.push_back(*cmd);
|
||||
num_batched_kernels += cmd.value().num_kernels;
|
||||
}
|
||||
|
||||
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
|
||||
num_batched_kernels = 0;
|
||||
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
|
||||
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
|
||||
num_batched_kernels = 0;
|
||||
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
|
||||
// Process events and check for completed submissions
|
||||
ctx->global_ctx->instance.ProcessEvents();
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
|
||||
commands.clear();
|
||||
}
|
||||
}
|
||||
if (!commands.empty()) {
|
||||
auto new_futures =
|
||||
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
|
||||
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
|
||||
commands.clear();
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
|
||||
// If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
|
||||
if (contains_set_rows) {
|
||||
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
|
||||
ctx->set_rows_host_error_buf.GetSize());
|
||||
wgpu::CommandBuffer set_rows_commands = encoder.Finish();
|
||||
ctx->global_ctx->queue.Submit(1, &set_rows_commands);
|
||||
ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->set_rows_host_error_buf.GetSize());
|
||||
const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
||||
}
|
||||
ctx->set_rows_host_error_buf.Unmap();
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, subs);
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
||||
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
|
||||
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
|
||||
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
|
||||
|
||||
@@ -11,7 +11,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
|
||||
shmem[idx + 2] = val.z;
|
||||
shmem[idx + 3] = val.w;
|
||||
}
|
||||
#endif
|
||||
#endif // VEC
|
||||
|
||||
#ifdef SCALAR
|
||||
#define VEC_SIZE 1
|
||||
@@ -23,7 +23,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
|
||||
fn store_shmem(val: f16, idx: u32) {
|
||||
shmem[idx] = val;
|
||||
}
|
||||
#endif
|
||||
#endif // SCALAR
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_FLOAT
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
@@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_FLOAT
|
||||
|
||||
#ifdef INIT_SRC1_SHMEM_FLOAT
|
||||
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
|
||||
@@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
|
||||
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC1_SHMEM_FLOAT
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
@@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_Q4_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_lo;
|
||||
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_0
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_1
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 1u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q2_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 42u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
// Use standard thread layout instead of lane/row_group
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx + 40u];
|
||||
let dmin = src0[scale_idx + 41u];
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
let pos_in_32 = k_in_block % 32u;
|
||||
|
||||
let q_b_idx = (block_of_32 / 4u) * 32u;
|
||||
let shift = (block_of_32 % 4u) * 2u;
|
||||
let k = (pos_in_32 / 16u) * 16u;
|
||||
let l = pos_in_32 % 16u;
|
||||
|
||||
let is = k_in_block / 16u;
|
||||
|
||||
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
|
||||
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
|
||||
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q2_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q3_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 55u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx + 54u];
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
let kmask2: u32 = 0x0f0f0f0fu;
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
let scale_0 = src0[scale_idx + 48u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
|
||||
|
||||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
let hmask_0 = src0[scale_idx + (2u*i)];
|
||||
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
|
||||
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
let qs_0 = src0[scale_idx + 16u + (2u*i)];
|
||||
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
|
||||
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
let pos_in_half = k_in_block % 128u; // 0-127
|
||||
let shift_group = pos_in_half / 32u; // 0-3
|
||||
let pos_in_32 = pos_in_half % 32u; // 0-31
|
||||
let k_group = pos_in_32 / 16u; // 0 or 1
|
||||
let l = pos_in_32 % 16u; // 0-15
|
||||
|
||||
let q_b_idx = half * 32u; // 0 or 32
|
||||
let shift = shift_group * 2u; // 0, 2, 4, 6
|
||||
let k = k_group * 16u; // 0 or 16
|
||||
let is = k_in_block / 16u; // 0-15
|
||||
|
||||
// m increments every 32 elements across entire 256 element block
|
||||
let m_shift = k_in_block / 32u; // 0-7
|
||||
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
|
||||
|
||||
let sc = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let dl = d * (f16(sc) - 32.0);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
|
||||
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
let q_val = (f16(qs_val) - f16(hm)) * dl;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q3_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 72u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
// Outer loop over 64-element groups (alternating q_b_idx)
|
||||
// Inner loop over 2 shifts per group
|
||||
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 88u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
|
||||
// But u increments EVERY 32 elements (after each l loop)
|
||||
let group_of_64 = k_in_block / 64u; // 0-3
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
|
||||
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
|
||||
let u_shift = k_in_block / 32u; // 0-7
|
||||
let u: u32 = 1u << u_shift;
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
|
||||
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
|
||||
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
|
||||
|
||||
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q5_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q6_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let half = k_in_block / 128u;
|
||||
let pos_in_half = k_in_block % 128u;
|
||||
let quarter = pos_in_half / 32u;
|
||||
let l = pos_in_half % 32u;
|
||||
|
||||
let ql_b_idx = half * 64u;
|
||||
let qh_b_idx = half * 32u;
|
||||
let sc_b_idx = half * 8u;
|
||||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13_word = ql13_flat / 4u;
|
||||
let ql13 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql13_word],
|
||||
src0[scale_idx + 2u * ql13_word + 1u]
|
||||
));
|
||||
let ql13_b = get_byte(ql13, ql13_flat % 4u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24_word = ql24_flat / 4u;
|
||||
let ql24 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql24_word],
|
||||
src0[scale_idx + 2u * ql24_word + 1u]
|
||||
));
|
||||
let ql24_b = get_byte(ql24, ql24_flat % 4u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh_word = qh_flat / 4u;
|
||||
let qh = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 64u + 2u * qh_word],
|
||||
src0[scale_idx + 64u + 2u * qh_word + 1u]
|
||||
));
|
||||
let qh_b = get_byte(qh, qh_flat % 4u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
|
||||
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
|
||||
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
|
||||
|
||||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc_word = sc_idx / 4u;
|
||||
let sc = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 96u + 2u * sc_word],
|
||||
src0[scale_idx + 96u + 2u * sc_word + 1u]
|
||||
));
|
||||
let sc_val = get_byte_i32(sc, sc_idx % 4u);
|
||||
|
||||
let d = src0[scale_idx + 104u];
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
q_val = q1;
|
||||
} else if (quarter == 1u) {
|
||||
q_val = q2;
|
||||
} else if (quarter == 2u) {
|
||||
q_val = q3;
|
||||
} else {
|
||||
q_val = q4;
|
||||
}
|
||||
|
||||
shmem[elem_idx] = d * f16(sc_val) * q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q6_K
|
||||
|
||||
@@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
|
||||
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
|
||||
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
|
||||
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "common_decls.tmpl"
|
||||
@@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 10u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = f32(src0[scale_idx + 1u]);
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
let q_lo = f32(q_byte & 0xF) * d + m;
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q5_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 11u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q5_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 12u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
|
||||
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 17u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1 + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q8_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 18u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + f32(m);
|
||||
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q6_K
|
||||
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
|
||||
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
|
||||
let aligned = byte_offset & ~3u;
|
||||
let idx = bbase + aligned / 2u;
|
||||
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
|
||||
}
|
||||
|
||||
fn byte_of(v: u32, b: u32) -> u32 {
|
||||
return (v >> (b * 8u)) & 0xFFu;
|
||||
}
|
||||
|
||||
fn sbyte_of(v: u32, b: u32) -> i32 {
|
||||
let raw = i32((v >> (b * 8u)) & 0xFFu);
|
||||
return select(raw, raw - 256, raw >= 128);
|
||||
}
|
||||
|
||||
fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let tid = tig / 2u;
|
||||
let ix = tig % 2u;
|
||||
let ip = tid / 8u;
|
||||
let il = tid % 8u;
|
||||
let l0 = 4u * il;
|
||||
let is = 8u * ip + l0 / 16u;
|
||||
|
||||
let y_offset = 128u * ip + l0;
|
||||
let q_offset_l = 64u * ip + l0;
|
||||
let q_offset_h = 32u * ip + l0;
|
||||
|
||||
let nb = tile_size / BLOCK_SIZE;
|
||||
let k_block_start = k_outer / BLOCK_SIZE;
|
||||
|
||||
// Aligned scale byte position (is can be odd)
|
||||
let sc_base_byte = 192u + (is & ~3u);
|
||||
let sc_byte_pos = is & 3u;
|
||||
|
||||
var local_sum = 0.0;
|
||||
|
||||
for (var i = ix; i < nb; i += 2u) {
|
||||
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
|
||||
|
||||
let d_raw = load_u32_at(bbase, 208u);
|
||||
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
|
||||
|
||||
let ql1_u32 = load_u32_at(bbase, q_offset_l);
|
||||
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
|
||||
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
|
||||
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
|
||||
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
|
||||
|
||||
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
|
||||
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
|
||||
let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
|
||||
let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
|
||||
|
||||
var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
|
||||
for (var l = 0u; l < 4u; l++) {
|
||||
let y_base = i * BLOCK_SIZE + y_offset + l;
|
||||
let yl0 = f32(shared_vector[y_base]);
|
||||
let yl1 = f32(shared_vector[y_base + 32u]);
|
||||
let yl2 = f32(shared_vector[y_base + 64u]);
|
||||
let yl3 = f32(shared_vector[y_base + 96u]);
|
||||
|
||||
let q1b = byte_of(ql1_u32, l);
|
||||
let q2b = byte_of(ql2_u32, l);
|
||||
let qhb = byte_of(qh_u32, l);
|
||||
|
||||
let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
|
||||
let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
|
||||
let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32);
|
||||
let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
|
||||
|
||||
sums[0] += yl0 * dq0;
|
||||
sums[1] += yl1 * dq1;
|
||||
sums[2] += yl2 * dq2;
|
||||
sums[3] += yl3 * dq3;
|
||||
}
|
||||
|
||||
local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
|
||||
sums[2] * f32(sc4) + sums[3] * f32(sc6));
|
||||
}
|
||||
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
@@ -191,4 +478,3 @@ fn main(
|
||||
dst[dst_idx / VEC_SIZE] = store_val(group_base);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -177,6 +177,8 @@ class Keys:
|
||||
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
|
||||
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
|
||||
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
|
||||
KEY_LENGTH_SWA = "{arch}.attention.key_length_swa"
|
||||
VALUE_LENGTH_SWA = "{arch}.attention.value_length_swa"
|
||||
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
|
||||
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
|
||||
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
|
||||
@@ -188,6 +190,7 @@ class Keys:
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
|
||||
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
|
||||
FREQ_BASE = "{arch}.rope.freq_base"
|
||||
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
|
||||
|
||||
@@ -773,6 +773,12 @@ class GGUFWriter:
|
||||
def add_value_length_mla(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
|
||||
|
||||
def add_key_length_swa(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KEY_LENGTH_SWA.format(arch=self.arch), length)
|
||||
|
||||
def add_value_length_swa(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_LENGTH_SWA.format(arch=self.arch), length)
|
||||
|
||||
def add_indexer_head_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count)
|
||||
|
||||
@@ -946,6 +952,9 @@ class GGUFWriter:
|
||||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_rope_dimension_count_swa(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT_SWA.format(arch=self.arch), count)
|
||||
|
||||
def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
|
||||
self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
|
||||
|
||||
|
||||
@@ -230,11 +230,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_SWA, "%s.attention.key_length_swa" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_SWA, "%s.attention.value_length_swa" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
|
||||
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
|
||||
|
||||
@@ -234,11 +234,14 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_SWA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_SWA,
|
||||
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
|
||||
LLM_KV_ATTENTION_INDEXER_TOP_K,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_FREQ_BASE_SWA,
|
||||
|
||||
+12
-8
@@ -2876,19 +2876,23 @@ llama_context * llama_init_from_model(
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
||||
if (model->hparams.n_embd_head_k % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
||||
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
|
||||
return nullptr;
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
||||
__func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
||||
if (model->hparams.n_embd_head_v % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
|
||||
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
|
||||
return nullptr;
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
|
||||
LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
|
||||
__func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||
}
|
||||
const char * int_end = parse_int(pos);
|
||||
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
|
||||
uint64_t min_times = std::stoull(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
|
||||
uint64_t max_times = UINT64_MAX; // default: no max limit
|
||||
@@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
|
||||
if (is_digit_char(*pos)) {
|
||||
const char * int_end = parse_int(pos);
|
||||
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||
max_times = std::stoull(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
}
|
||||
|
||||
|
||||
+3
-3
@@ -849,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
ubatch (params.ubatch),
|
||||
n_embd (hparams.n_embd),
|
||||
n_layer (hparams.n_layer),
|
||||
n_rot (hparams.n_rot),
|
||||
n_rot (hparams.n_rot()),
|
||||
n_ctx (cparams.n_ctx),
|
||||
n_head (hparams.n_head()),
|
||||
n_head_kv (hparams.n_head_kv()),
|
||||
n_embd_head_k (hparams.n_embd_head_k),
|
||||
n_embd_head_k (hparams.n_embd_head_k()),
|
||||
n_embd_k_gqa (hparams.n_embd_k_gqa()),
|
||||
n_embd_head_v (hparams.n_embd_head_v),
|
||||
n_embd_head_v (hparams.n_embd_head_v()),
|
||||
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
||||
n_expert (hparams.n_expert),
|
||||
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
|
||||
|
||||
+28
-4
@@ -62,6 +62,14 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const {
|
||||
return n_head/n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_rot(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return is_swa(il) ? n_rot_swa : n_rot_full;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_inp() const {
|
||||
uint32_t n_embd_inp = n_embd;
|
||||
|
||||
@@ -76,16 +84,32 @@ uint32_t llama_hparams::n_embd_out() const {
|
||||
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_k(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_v(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
return n_embd_head_k * n_head_kv;
|
||||
return n_embd_head_k(il) * n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
return n_embd_head_v * n_head_kv;
|
||||
return n_embd_head_v(il) * n_head_kv;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_k_gqa_variable() const {
|
||||
@@ -197,11 +221,11 @@ bool llama_hparams::is_mla() const {
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_k_mla() const {
|
||||
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
|
||||
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k();
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_head_v_mla() const {
|
||||
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
|
||||
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v();
|
||||
}
|
||||
|
||||
bool llama_hparams::has_kv(uint32_t il) const {
|
||||
|
||||
+16
-3
@@ -44,13 +44,20 @@ struct llama_hparams {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
|
||||
uint32_t n_rot;
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_expert = 0;
|
||||
uint32_t n_expert_used = 0;
|
||||
uint32_t n_rel_attn_bkts = 0;
|
||||
|
||||
// different head size for full_attention and SWA layers
|
||||
uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_embd_head_k_swa;
|
||||
uint32_t n_embd_head_v_swa;
|
||||
|
||||
// different RoPE dimensions for full_attention and SWA layers
|
||||
uint32_t n_rot_full;
|
||||
uint32_t n_rot_swa;
|
||||
|
||||
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
|
||||
uint32_t n_embd_head_k_mla_impl = 0;
|
||||
uint32_t n_embd_head_v_mla_impl = 0;
|
||||
@@ -247,12 +254,18 @@ struct llama_hparams {
|
||||
|
||||
uint32_t n_gqa(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_rot(uint32_t il = 0) const;
|
||||
|
||||
// dimension of main + auxiliary input embeddings
|
||||
uint32_t n_embd_inp() const;
|
||||
|
||||
// dimension of output embeddings
|
||||
uint32_t n_embd_out() const;
|
||||
|
||||
// dimension of key/value embeddings for each head (per layer)
|
||||
uint32_t n_embd_head_k(uint32_t il = 0) const;
|
||||
uint32_t n_embd_head_v(uint32_t il = 0) const;
|
||||
|
||||
// dimension of key embeddings across all k-v heads
|
||||
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
|
||||
|
||||
|
||||
+14
-16
@@ -1033,8 +1033,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
|
||||
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
||||
|
||||
return ggml_view_4d(ctx, k,
|
||||
hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(k->type, hparams.n_embd_head_k),
|
||||
hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(k->type, hparams.n_embd_head_k(il)),
|
||||
ggml_row_size(k->type, n_embd_k_gqa),
|
||||
ggml_row_size(k->type, n_embd_k_gqa*kv_size),
|
||||
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
|
||||
@@ -1056,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
|
||||
if (!v_trans) {
|
||||
// note: v->nb[1] <= v->nb[2]
|
||||
return ggml_view_4d(ctx, v,
|
||||
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
||||
hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns,
|
||||
ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1]
|
||||
ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2]
|
||||
ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
|
||||
ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
|
||||
@@ -1065,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
|
||||
|
||||
// note: v->nb[1] > v->nb[2]
|
||||
return ggml_view_4d(ctx, v,
|
||||
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
|
||||
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
|
||||
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns,
|
||||
ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1]
|
||||
ggml_row_size(v->type, kv_size), // v->nb[2]
|
||||
ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
|
||||
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
|
||||
@@ -1544,7 +1544,8 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const {
|
||||
float freq_scale,
|
||||
uint32_t il) const {
|
||||
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
||||
|
||||
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
||||
@@ -1552,7 +1553,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
||||
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
const auto & n_rot = hparams.n_rot(il);
|
||||
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
||||
// @ngxson : this is a workaround
|
||||
// for M-RoPE, we want to rotate the whole vector when doing KV shift
|
||||
@@ -1606,13 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
auto * ctx = res->get_ctx();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
|
||||
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
||||
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
@@ -1626,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
|
||||
const auto n_rot = hparams.n_rot(il);
|
||||
const auto n_embd_head_k = hparams.n_embd_head_k(il);
|
||||
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
||||
|
||||
const float freq_base_l = model.get_rope_freq_base (cparams, il);
|
||||
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
|
||||
|
||||
@@ -1638,7 +1636,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
@@ -264,7 +264,8 @@ private:
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
float freq_scale,
|
||||
uint32_t il) const;
|
||||
|
||||
ggml_cgraph * build_graph_shift(
|
||||
llm_graph_result * res,
|
||||
|
||||
@@ -918,7 +918,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int n_embd_head = hparams.n_embd_head_v;
|
||||
const int n_embd_head = hparams.n_embd_head_v();
|
||||
const int n_head = hparams.n_head();
|
||||
ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
|
||||
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
|
||||
|
||||
@@ -186,8 +186,10 @@ void llama_model_saver::add_kv_from_model() {
|
||||
add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true);
|
||||
add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
|
||||
add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full);
|
||||
add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
|
||||
add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
@@ -199,7 +201,8 @@ void llama_model_saver::add_kv_from_model() {
|
||||
|
||||
const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
|
||||
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full);
|
||||
add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa);
|
||||
add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train);
|
||||
// add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name
|
||||
add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
|
||||
|
||||
+51
-39
@@ -459,26 +459,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
|
||||
// gpt-j n_rot = rotary_dim
|
||||
|
||||
hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
|
||||
hparams.n_embd_head_k_full = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false);
|
||||
|
||||
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
|
||||
hparams.n_embd_head_v_full = hparams.n_embd / hparams.n_head();
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);
|
||||
|
||||
// sanity check for n_rot (optional)
|
||||
hparams.n_rot = hparams.n_embd_head_k;
|
||||
hparams.n_rot_full = hparams.n_embd_head_k_full;
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full, false);
|
||||
|
||||
if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) {
|
||||
if (hparams.n_rot != hparams.n_embd_head_k) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
||||
if (hparams.n_rot_full != hparams.n_embd_head_k_full) {
|
||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot_full, hparams.n_embd_head_k_full));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
hparams.n_rot = 0;
|
||||
hparams.n_embd_head_k = 0;
|
||||
hparams.n_embd_head_v = 0;
|
||||
hparams.n_rot_full = 0;
|
||||
hparams.n_embd_head_k_full = 0;
|
||||
hparams.n_embd_head_v_full = 0;
|
||||
}
|
||||
|
||||
// head size and n_rot for SWA layers
|
||||
{
|
||||
hparams.n_embd_head_k_swa = hparams.n_embd_head_k_full;
|
||||
hparams.n_embd_head_v_swa = hparams.n_embd_head_v_full;
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa, false);
|
||||
|
||||
hparams.n_rot_swa = hparams.n_rot_full;
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false);
|
||||
}
|
||||
|
||||
// for differentiating model types
|
||||
@@ -1114,10 +1125,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
// Load attention parameters
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
|
||||
} break;
|
||||
case LLM_ARCH_PLAMO3:
|
||||
{
|
||||
@@ -1212,7 +1219,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
|
||||
hparams.f_attention_scale = type == LLM_TYPE_27B
|
||||
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
@@ -1245,7 +1252,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
|
||||
hparams.f_attention_scale = type == LLM_TYPE_27B
|
||||
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3N:
|
||||
{
|
||||
@@ -1294,7 +1301,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
case 24: type = LLM_TYPE_0_3B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
||||
hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k()));
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
@@ -2487,7 +2494,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl);
|
||||
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
|
||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot);
|
||||
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda);
|
||||
|
||||
@@ -2518,6 +2524,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
|
||||
// full_attention layer only use half of the RoPE dimensions
|
||||
hparams.n_rot_full = hparams.n_rot_full / 2;
|
||||
|
||||
// MoE + SWA parameters
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
|
||||
@@ -2661,13 +2670,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k();
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v();
|
||||
const int64_t n_ff = hparams.n_ff();
|
||||
const int64_t n_embd_gqa = n_embd_v_gqa;
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_token_types = vocab.n_token_types();
|
||||
const int64_t n_rot = hparams.n_rot;
|
||||
const int64_t n_rot = hparams.n_rot();
|
||||
const int64_t n_expert = hparams.n_expert;
|
||||
const int64_t n_expert_used = hparams.n_expert_used;
|
||||
const int64_t n_ctx_train = hparams.n_ctx_train;
|
||||
@@ -2967,8 +2976,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_MINICPM3:
|
||||
{
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
|
||||
|
||||
const int64_t q_lora_rank = hparams.n_lora_q;
|
||||
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
||||
@@ -3840,8 +3849,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
|
||||
|
||||
// attention parameters
|
||||
const uint32_t qk_dim = hparams.n_embd_head_k;
|
||||
const uint32_t v_dim = hparams.n_embd_head_v;
|
||||
const uint32_t qk_dim = hparams.n_embd_head_k();
|
||||
const uint32_t v_dim = hparams.n_embd_head_v();
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
@@ -3901,8 +3910,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_PLAMO3:
|
||||
{
|
||||
const int64_t head_dim_q = hparams.n_embd_head_k;
|
||||
const int64_t head_dim_v = hparams.n_embd_head_v;
|
||||
const int64_t head_dim_q = hparams.n_embd_head_k();
|
||||
const int64_t head_dim_v = hparams.n_embd_head_v();
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
@@ -4649,7 +4658,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
{
|
||||
const uint32_t head_dim = hparams.n_embd_head_k;
|
||||
const uint32_t head_dim = hparams.n_embd_head_k();
|
||||
const int64_t n_qo_dim = n_head * head_dim;
|
||||
const int64_t n_kv_dim = n_head_kv * head_dim;
|
||||
|
||||
@@ -4878,7 +4887,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
|
||||
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
|
||||
GGML_ASSERT(n_embd_head_qk_nope >= 1);
|
||||
|
||||
@@ -4957,8 +4966,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_PLM:
|
||||
{
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
|
||||
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -5396,7 +5405,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
|
||||
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
|
||||
|
||||
const int64_t q_lora_rank = hparams.n_lora_q;
|
||||
@@ -5680,7 +5689,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int64_t n_expert = hparams.n_expert;
|
||||
const int64_t n_expert_used = hparams.n_expert_used;
|
||||
const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp;
|
||||
const int64_t head_dim = hparams.n_embd_head_k;
|
||||
const int64_t head_dim = hparams.n_embd_head_k();
|
||||
const int64_t n_qo_dim = n_head * head_dim;
|
||||
const int64_t n_kv_dim = n_head_kv * head_dim;
|
||||
|
||||
@@ -6968,7 +6977,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
// Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA)
|
||||
// Note: hparams.n_rot may be 72 (from conversion) but actual is 64
|
||||
const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim
|
||||
const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim
|
||||
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0);
|
||||
// Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled)
|
||||
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i),
|
||||
@@ -7339,7 +7348,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
// ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer.
|
||||
uint32_t n_rot_max = 0;
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
n_rot_max = std::max(n_rot_max, hparams.n_rot);
|
||||
n_rot_max = std::max(n_rot_max, hparams.n_rot(i));
|
||||
}
|
||||
if (n_rot_max == 0) {
|
||||
n_rot_max = n_rot;
|
||||
@@ -7674,11 +7683,11 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
|
||||
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
|
||||
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
||||
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
|
||||
@@ -7702,6 +7711,9 @@ void llama_model::print_info() const {
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa);
|
||||
LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa);
|
||||
LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa);
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
||||
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
||||
|
||||
+466
-285
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
+2
-2
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
|
||||
llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
#include <float.h>
|
||||
|
||||
llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * inpL;
|
||||
ggml_tensor * cur;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
const float f_logit_scale = hparams.f_logit_scale;
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
|
||||
llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
const float f_logit_scale = hparams.f_logit_scale;
|
||||
|
||||
|
||||
+3
-3
@@ -1,11 +1,11 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
+3
-3
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -8,7 +8,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
//copied from qwen2
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
|
||||
llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
template <bool iswa>
|
||||
llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
|
||||
llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
template <bool iswa>
|
||||
llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
n_embd_head(model.hparams.n_embd_head_k),
|
||||
n_embd_head(model.hparams.n_embd_head_k()),
|
||||
n_embd_altup(model.hparams.n_embd_altup),
|
||||
n_altup(model.hparams.n_altup),
|
||||
i_altup_act(model.hparams.i_altup_act) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
+2
-2
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
+2
-2
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * pos;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
|
||||
llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -5,10 +5,10 @@ llm_build_granite::llm_build_granite(
|
||||
const llm_graph_params & params)
|
||||
: llm_graph_context(params) {
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
+3
-3
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_chunk_expert = n_expert / hparams.n_group_experts;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
+2
-2
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
// JAIS-2 model graph builder
|
||||
// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings
|
||||
llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -102,7 +102,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
|
||||
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
||||
// qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
|
||||
// Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot; // config.qk_rope_head_dim
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot(); // config.qk_rope_head_dim
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128
|
||||
// Attention scale for MLA
|
||||
const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
|
||||
|
||||
+1
-1
@@ -39,7 +39,7 @@ llm_build_lfm2<iswa>::llm_build_lfm2(const llama_model & model, const llm_graph_
|
||||
inp_attn_type * inp_attn,
|
||||
int il) -> ggml_tensor * {
|
||||
GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
|
||||
const auto n_embd_head = hparams.n_embd_head_v;
|
||||
const auto n_embd_head = hparams.n_embd_head_v();
|
||||
const auto n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
auto * q = build_lora_mm(model.layers[il].wq, cur);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
// LLaDA is similar to LLaMA but uses non-causal attention for diffusion
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
template <bool embed>
|
||||
llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -168,8 +168,9 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
||||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % (n_group*d_state) == 0);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % d_state == 0);
|
||||
GGML_ASSERT(d_inner % n_group == 0);
|
||||
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
||||
+13
-13
@@ -5,10 +5,10 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
|
||||
const int64_t n_embd_base = 256;
|
||||
const float scale_embd = 12.0f;
|
||||
const float scale_depth = 1.4f;
|
||||
const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
|
||||
const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k()));
|
||||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
@@ -51,21 +51,21 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
|
||||
LLM_NORM_RMS, il);
|
||||
cb(q, "q", il);
|
||||
|
||||
// {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
|
||||
// {q_lora_rank, n_head * hparams.n_embd_head_k()} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k(), n_tokens}
|
||||
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
|
||||
cb(q, "q", il);
|
||||
|
||||
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
||||
ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k()),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k() * n_head),
|
||||
0);
|
||||
cb(q_nope, "q_nope", il);
|
||||
|
||||
// and {n_head * n_embd_head_qk_rope, n_tokens}
|
||||
ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k()),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k() * n_head),
|
||||
ggml_row_size(q->type, n_embd_head_qk_nope));
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
@@ -97,15 +97,15 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
|
||||
|
||||
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
||||
ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
|
||||
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v()),
|
||||
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v())),
|
||||
0);
|
||||
cb(k_nope, "k_nope", il);
|
||||
|
||||
// and {n_head * n_embd_head_v, n_tokens}
|
||||
ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
|
||||
ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v(), n_head, n_tokens,
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())),
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())*n_head),
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
// GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
// GGML_ASSERT(n_embd_head == n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
GGML_ASSERT(n_embd_head == n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
+2
-2
@@ -3,10 +3,10 @@
|
||||
|
||||
|
||||
llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * pos;
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user