Compare commits

..

26 Commits

Author SHA1 Message Date
Adrien Gallouët 08f3f4a8a3 ggml : cleanup path_str() (#18928)
- Remove pragmas as `std::codecvt_utf8` is not used.
- Avoid implicit `strlen()`.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-01-20 11:42:49 +01:00
Georgi Gerganov 271191906c metal : enable FA for MLA heads (#18950) 2026-01-20 12:21:28 +02:00
Daniel Bevenius 7dee9ff59a convert : use n_groups instead of hardcoded values in reshape (#18929)
* convert : use n_groups instead of hardcoded values in reshape

This commit modifies the conversion script for NemotronHModel to use
the 'n_groups' hyperparameter, and allow Python to calculate the the
last dimension, using -1, when reshaping the 'mixer.norm.weight' tensor.

* use self.n_group instead of self.hparams["n_groups"]
2026-01-20 06:55:24 +01:00
Xuan-Son Nguyen 6df686bee6 server : refactor oai_parser_opt, move it to server_chat_params (#18937)
* server_chat_params

* move chat format into CLI

* use meta whenever possible

* clean up, no more chatml fallback
2026-01-19 23:28:01 +01:00
ddh0 1706a6d7c6 convert : support Glm4MoeLite (#18936)
* initial commit for branch

* add glm-4.7-flash, move tokenizer hash

* use `glm4` pretok

* silence flake8 E302 (CI)

* apply review feedback

* add <|user|> as eog

* also add EOG `<|observation|>`

* revert llama-vocab

* inherit vocab from glm4

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-01-19 23:09:20 +01:00
Sigbjørn Skjæret 959ecf7f23 jinja : fix undefined keys and attributes and int/float as bool (#18924)
* fix undefined keys and attributes

* add falsy tests

* as_bool for integers and floats

* more falsy/truthy tests

* --typo
2026-01-19 20:29:43 +01:00
Sigbjørn Skjæret 4037093c66 ci : run test-jinja -py on high perf [no ci] (#18916) 2026-01-19 20:29:15 +01:00
Lennart Austenfeld 18361c579c server: fix memory reservations in populate_token_probs (#18787) 2026-01-19 19:13:31 +01:00
Georgi Gerganov 365a3e8c31 ggml : add ggml_build_forward_select (#18550)
* ggml : add ggml_build_forward_select

* cuda : adapt CUDA graph compat to new feature

* vulkan : update logic to handle command buffer closing

* ggml : check compute for fusion

* ggml : add comment
2026-01-19 20:03:19 +02:00
Daniel Bevenius 3d55846a5c model-conversion : add BUILD_DIR variable to run-converted-model scripts (#18927)
This commit adds a BUILD_DIR variable to the scripts used for running
converted models.

The motivation for this is that currently the `build` directory is
hardcoded and it can be useful to specify a different build directory,
with builds for different configurations.
2026-01-19 13:12:38 +01:00
Julius Tischbein 287a33017b llama : Extend fallback, fix fileno for dio file, exclude case that mmap uses dio file (#18887) 2026-01-18 18:35:57 +02:00
Francisco Herrera 293a1565dc docs: add linux to index (#18907) 2026-01-18 18:03:35 +08:00
Xuan-Son Nguyen fe44d35574 tests : add test-jinja -py option for cross-checking (#18906)
* tests : add test-jinja -py option or cross-checking

* Update tests/test-jinja.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix + add source

* SandboxedEnvironment

* fix array.map case

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-01-18 08:14:27 +01:00
Sigbjørn Skjæret bbcdac0189 jinja : fix object item order (and properly implement dictsort) (#18904)
* fix object item order

* as_ordered_object

* copy whole object
2026-01-18 03:40:06 +01:00
Sigbjørn Skjæret d03c45c9c5 jinja : attribute support for join, map and sort (#18883)
* support negative array index and default value

* attribute support (int and str) for join, map and sort

* add tests

* update CODEOWNERS

* improve fixme sorting comment
2026-01-18 02:53:01 +01:00
Sigbjørn Skjæret 10c98cbdf6 jinja : add missing tojson filter for bool (#18900)
* add missing tojson for bool

* add more literal tests
2026-01-18 01:05:09 +01:00
Sigbjørn Skjæret 420960ab92 jinja : fix lexing of float literals with sign (#18901)
* fix lexing of float literals with sign

* add test

* consume_numeric
2026-01-18 00:57:51 +01:00
Xuan-Son Nguyen f55b033ae6 jinja: correct member access rule (#18905) 2026-01-18 00:48:55 +01:00
lhez d1b4757ded opencl: fix q6_K mv for m=1 (#18893) 2026-01-17 13:50:32 -08:00
Sigbjørn Skjæret 57c0beaed0 ci : add label for jinja changes (#18903) 2026-01-17 21:52:02 +01:00
Georgi Gerganov 2fbde785bc kv-cache : optimize KQ mask construction (#18842)
* kv-cache : optimize KQ mask construction

* cont : add explanation + improve

* cont : fix
2026-01-17 15:42:42 +02:00
Reese Levine a89002f07b ggml webgpu: support for backend sampling (#18880)
* ggml webgpu: add SOFTPLUS unary operator

Implements SOFTPLUS (log(1 + exp(x))) with f16/f32 support. Uses f32
precision for intermediate calculations to prevent f16 overflow.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support
* Follow Vulkan backend numerical stability pattern

* ggml webgpu: add EXPM1 unary operator

Implements EXPM1 (exp(x) - 1) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add FLOOR unary operator

Implements FLOOR (rounds down to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add CEIL unary operator

Implements CEIL (rounds up to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add ROUND unary operator

Implements ROUND (rounds to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add TRUNC unary operator

Implements TRUNC (truncates towards zero) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* docs : update WebGPU support for unary operators (FLOOR, CEIL, ROUND, TRUNC, EXPM1, SOFTPLUS)

* Updates to webgpu get_memory

* Add argmax

* Add argmax,cumsum,sum,sum_rows

* Add necessary CPY/GET_ROWS operators

* Support for argsort using multi-pass strategy

* Update set_rows for i32 indices, move to pre-wgsl

* Port unary operators to pre-wgsl and support FILL

* Implement PAD

* Add support for top-k

* clean up, scope pipeline init mutex

* fix newline

* Add support for log

* Update LOG for better precision, and ops doc

---------

Co-authored-by: Abhijit Ramesh <abhijitramesh2k@gmail.com>
2026-01-16 16:12:43 -08:00
Thore Koritzius 388ce82241 ggml : extend ggml_pool_1d + metal (#16429)
* chore: resolve conflicts

* feat: ggml metal impl

* fix: ggml_metal_kargs_pool_1d struct

* fix: require contiguous input

* chore: test pool_1d

* chore: limit pool1d test cases to p0=0 and s0=k0 to conform with asserts

* chore: add p0 and s0 to testing

* fix: allow padding for cpu and metal

* Update ggml/src/ggml-metal/ggml-metal.metal

* fix: correct single-threaded loop

* ggml : cleanup

* tests : add ne[1] != 1 tests

* fix: ne[1] handling in np

* cont : fixes

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-01-16 16:59:56 +02:00
hipudding 6ba6a3c76f docs : update ops.md for CANN backend (#18654) 2026-01-16 13:32:17 +01:00
Perry Naseck 0802d4cfb3 ggml-blas: hide warnings from included BLAS headers (#18818)
* fix compile def openblas, blis for compat libs, nvpl compile def, warn if no blas vendor set

* ggml-blas: hide warnings from included BLAS headers
2026-01-16 13:38:25 +02:00
Tarek Dakhran c945aaaef2 mtmd : Fix ASR for LFM2.5-Audio-1.5B (#18876) 2026-01-16 11:23:08 +01:00
73 changed files with 26183 additions and 13570 deletions
+4 -1
View File
@@ -89,7 +89,10 @@ nix:
embedding:
- changed-files:
- any-glob-to-any-file: examples/embedding/
jinja parser:
- changed-files:
- any-glob-to-any-file:
- common/jinja/**
Ascend NPU:
- changed-files:
- any-glob-to-any-file:
+1
View File
@@ -15,6 +15,7 @@
/common/common.* @ggerganov
/common/console.* @ggerganov
/common/http.* @angt
/common/jinja/ @ngxson @CISC @aldehir
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/peg-parser.* @aldehir
+1 -1
View File
@@ -254,7 +254,7 @@ function gg_run_ctest_release {
(time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
if [ -z ${GG_BUILD_LOW_PERF} ]; then
(time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log
(time ctest --output-on-failure -L 'main|python' ) 2>&1 | tee -a $OUT/${ci}-ctest.log
else
(time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
fi
+7 -7
View File
@@ -601,18 +601,18 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp
return tmpls->has_explicit_template;
}
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
if (variant != nullptr) {
if (strcmp(variant, "tool_use") == 0) {
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
if (!variant.empty()) {
if (variant == "tool_use") {
if (tmpls->template_tool_use) {
return tmpls->template_tool_use->source().c_str();
return tmpls->template_tool_use->source();
}
return nullptr;
return "";
} else {
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str());
}
}
return tmpls->template_default->source().c_str();
return tmpls->template_default->source();
}
common_chat_templates_ptr common_chat_templates_init(
+1 -1
View File
@@ -191,7 +191,7 @@ common_chat_templates_ptr common_chat_templates_init(
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
struct common_chat_params common_chat_templates_apply(
+12 -7
View File
@@ -91,6 +91,16 @@ lexer_result lexer::tokenize(const std::string & source) {
return str;
};
auto consume_numeric = [&]() -> std::string {
std::string num = consume_while(is_integer);
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
++pos; // Consume '.'
std::string frac = consume_while(is_integer);
num += "." + frac;
}
return num;
};
auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
if (pos + n >= src.size()) return false;
for (char c : chars) {
@@ -258,7 +268,7 @@ lexer_result lexer::tokenize(const std::string & source) {
++pos; // Consume the operator
// Check for numbers following the unary operator
std::string num = consume_while(is_integer);
std::string num = consume_numeric();
std::string value = std::string(1, ch) + num;
token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
// JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
@@ -307,12 +317,7 @@ lexer_result lexer::tokenize(const std::string & source) {
// Numbers
if (is_integer(ch)) {
start_pos = pos;
std::string num = consume_while(is_integer);
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
++pos; // Consume '.'
std::string frac = consume_while(is_integer);
num += "." + frac;
}
std::string num = consume_numeric();
// JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
tokens.push_back({token::numeric_literal, num, start_pos});
continue;
+23 -11
View File
@@ -268,8 +268,7 @@ value binary_expression::execute_impl(context & ctx) {
// String in object
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
auto key = left_val->as_string().str();
auto & obj = right_val->as_object();
bool has_key = obj.find(key) != obj.end();
bool has_key = right_val->has_key(key);
if (op.value == "in") {
return mk_val<value_bool>(has_key);
} else if (op.value == "not in") {
@@ -464,7 +463,7 @@ value for_statement::execute_impl(context & ctx) {
std::vector<value> items;
if (is_val<value_object>(iterable_val)) {
JJ_DEBUG("%s", "For loop over object keys");
auto & obj = iterable_val->as_object();
auto & obj = iterable_val->as_ordered_object();
for (auto & p : obj) {
auto tuple = mk_val<value_array>();
if (iterable_val->val_obj.is_key_numeric) {
@@ -560,6 +559,7 @@ value for_statement::execute_impl(context & ctx) {
for (size_t i = 0; i < filtered_items.size(); i++) {
JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size());
value_object loop_obj = mk_val<value_object>();
loop_obj->has_builtins = false; // loop object has no builtins
loop_obj->insert("index", mk_val<value_int>(i + 1));
loop_obj->insert("index0", mk_val<value_int>(i));
loop_obj->insert("revindex", mk_val<value_int>(filtered_items.size() - i));
@@ -717,6 +717,7 @@ value member_expression::execute_impl(context & ctx) {
value property;
if (this->computed) {
// syntax: obj[expr]
JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
int64_t arr_size = 0;
@@ -745,10 +746,24 @@ value member_expression::execute_impl(context & ctx) {
property = this->property->execute(ctx);
}
} else {
// syntax: obj.prop
if (!is_stmt<identifier>(this->property)) {
throw std::runtime_error("Non-computed member property must be an identifier");
throw std::runtime_error("Static member property must be an identifier");
}
property = mk_val<value_string>(cast_stmt<identifier>(this->property)->val);
std::string prop = property->as_string().str();
JJ_DEBUG("Member expression, object type %s, static property '%s'", object->type().c_str(), prop.c_str());
// behavior of jinja2: obj having prop as a built-in function AND 'prop', as an object key,
// then obj.prop returns the built-in function, not the property value.
// while obj['prop'] returns the property value.
// example: {"obj": {"items": 123}} -> obj.items is the built-in function, obj['items'] is 123
value val = try_builtin_func(ctx, prop, object, true);
if (!is_val<value_undefined>(val)) {
return val;
}
// else, fallthrough to normal property access below
}
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
@@ -763,11 +778,8 @@ value member_expression::execute_impl(context & ctx) {
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
}
auto key = property->as_string().str();
auto & obj = object->as_object();
auto it = obj.find(key);
if (it != obj.end()) {
val = it->second;
} else {
val = object->at(key, val);
if (is_val<value_undefined>(val)) {
val = try_builtin_func(ctx, key, object, true);
}
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
@@ -793,7 +805,7 @@ value member_expression::execute_impl(context & ctx) {
} else if (is_val<value_string>(property)) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
val = try_builtin_func(ctx, key, object);
val = try_builtin_func(ctx, key, object, true);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}
@@ -802,7 +814,7 @@ value member_expression::execute_impl(context & ctx) {
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
}
auto key = property->as_string().str();
val = try_builtin_func(ctx, key, object);
val = try_builtin_func(ctx, key, object, true);
}
if (ctx.is_get_stats && val && object && property) {
+3 -2
View File
@@ -56,6 +56,7 @@ struct context {
// src is optional, used for error reporting
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
env = mk_val<value_object>();
env->has_builtins = false; // context object has no builtins
env->insert("true", mk_val<value_bool>(true));
env->insert("True", mk_val<value_bool>(true));
env->insert("false", mk_val<value_bool>(false));
@@ -68,7 +69,7 @@ struct context {
context(const context & parent) : context() {
// inherit variables (for example, when entering a new scope)
auto & pvar = parent.env->as_object();
auto & pvar = parent.env->as_ordered_object();
for (const auto & pair : pvar) {
set_val(pair.first, pair.second);
}
@@ -265,7 +266,7 @@ struct comment_statement : public statement {
struct member_expression : public expression {
statement_ptr object;
statement_ptr property;
bool computed;
bool computed; // true if obj[expr] and false if obj.prop
member_expression(statement_ptr && object, statement_ptr && property, bool computed)
: object(std::move(object)), property(std::move(property)), computed(computed) {
+73 -54
View File
@@ -698,6 +698,7 @@ const func_builtins & value_bool_t::get_builtins() const {
bool val = args.get_pos(0)->as_bool();
return mk_val<value_string>(val ? "True" : "False");
}},
{"tojson", tojson},
};
return builtins;
}
@@ -775,19 +776,30 @@ const func_builtins & value_array_t::get_builtins() const {
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("join() first argument must be an array");
}
value val_delim = args.get_kwarg_or_pos("d", 1);
value val_attribute = args.get_kwarg_or_pos("attribute", 2);
if (!val_attribute->is_undefined()) {
throw not_implemented_exception("array attribute join not implemented");
}
value val_delim = args.get_kwarg_or_pos("d", 1);
value attribute = args.get_kwarg_or_pos("attribute", 2);
const auto & arr = args.get_pos(0)->as_array();
std::string delim = is_val<value_string>(val_delim) ? val_delim->as_string().str() : "";
const bool attr_is_int = is_val<value_int>(attribute);
if (!attribute->is_undefined() && !is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("join() attribute must be string or integer");
}
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::string result;
for (size_t i = 0; i < arr.size(); ++i) {
if (!is_val<value_string>(arr[i]) && !is_val<value_int>(arr[i]) && !is_val<value_float>(arr[i])) {
value val_arr = arr[i];
if (!attribute->is_undefined()) {
if (attr_is_int && is_val<value_array>(val_arr)) {
val_arr = val_arr->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
val_arr = val_arr->at(attr_name);
}
}
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
throw raised_exception("join() can only join arrays of strings or numerics");
}
result += arr[i]->as_string().str();
result += val_arr->as_string().str();
if (i < arr.size() - 1) {
result += delim;
}
@@ -802,26 +814,30 @@ const func_builtins & value_array_t::get_builtins() const {
}},
{"tojson", tojson},
{"map", [](const func_args & args) -> value {
args.ensure_count(2, 3);
args.ensure_count(2);
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("map: first argument must be an array");
}
value attribute = args.get_kwarg_or_pos("attribute", 1);
if (is_val<value_int>(attribute)) {
throw not_implemented_exception("map: integer attribute not implemented");
if (!is_val<value_kwarg>(args.get_args().at(1))) {
throw not_implemented_exception("map: filter-mapping not implemented");
}
if (!is_val<value_string>(attribute)) {
value attribute = args.get_kwarg_or_pos("attribute", 1);
const bool attr_is_int = is_val<value_int>(attribute);
if (!is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("map: attribute must be string or integer");
}
std::string attr_name = attribute->as_string().str();
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->as_string().str();
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
auto out = mk_val<value_array>();
auto arr = args.get_pos(0)->as_array();
for (const auto & item : arr) {
if (!is_val<value_object>(item)) {
throw raised_exception("map: item is not an object");
value attr_val;
if (attr_is_int) {
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
} else {
attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
}
value attr_val = item->at(attr_name, default_val);
out->push_back(attr_val);
}
return out;
@@ -847,29 +863,35 @@ const func_builtins & value_array_t::get_builtins() const {
return arr_editable->pop_at(index);
}},
{"sort", [](const func_args & args) -> value {
args.ensure_count(1, 3);
args.ensure_count(1, 4);
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("sort: first argument must be an array");
}
bool reverse = args.get_kwarg("reverse", mk_val<value_undefined>())->as_bool();
value attribute = args.get_kwarg("attribute", mk_val<value_undefined>());
std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
value attribute = args.get_kwarg_or_pos("attribute", 3);
// FIXME: sorting is currently always case sensitive
//const bool case_sensitive = val_case->as_bool(); // undefined == false
const bool reverse = val_reverse->as_bool(); // undefined == false
const bool attr_is_int = is_val<value_int>(attribute);
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
value val_a = a;
value val_b = b;
if (!attribute->is_undefined()) {
if (!is_val<value_object>(a) || !is_val<value_object>(b)) {
throw raised_exception("sort: items are not objects");
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
val_a = a->at(attr_int);
val_b = b->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
val_a = a->at(attr_name);
val_b = b->at(attr_name);
} else {
throw raised_exception("sort: unsupported object attribute comparison");
}
val_a = attr.empty() ? a : a->at(attr);
val_b = attr.empty() ? b : b->at(attr);
}
if (reverse) {
return value_compare(val_a, val_b, value_compare_op::gt);
} else {
return !value_compare(val_a, val_b, value_compare_op::gt);
}
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
});
return mk_val<value_array>(arr);
}},
@@ -888,6 +910,11 @@ const func_builtins & value_array_t::get_builtins() const {
const func_builtins & value_object_t::get_builtins() const {
if (!has_builtins) {
static const func_builtins no_builtins = {};
return no_builtins;
}
static const func_builtins builtins = {
// {"default", default_value}, // cause issue with gpt-oss
{"get", [](const func_args & args) -> value {
@@ -902,18 +929,13 @@ const func_builtins & value_object_t::get_builtins() const {
if (args.count() == 3) {
default_val = args.get_pos(2);
}
const auto & obj = args.get_pos(0)->as_object();
const value obj = args.get_pos(0);
std::string key = args.get_pos(1)->as_string().str();
auto it = obj.find(key);
if (it != obj.end()) {
return it->second;
} else {
return default_val;
}
return obj->at(key, default_val);
}},
{"keys", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(mk_val<value_string>(pair.first));
@@ -922,7 +944,7 @@ const func_builtins & value_object_t::get_builtins() const {
}},
{"values", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(pair.second);
@@ -931,7 +953,7 @@ const func_builtins & value_object_t::get_builtins() const {
}},
{"items", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
auto item = mk_val<value_array>();
@@ -945,7 +967,7 @@ const func_builtins & value_object_t::get_builtins() const {
{"string", tojson},
{"length", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_object();
const auto & obj = args.get_pos(0)->as_ordered_object();
return mk_val<value_int>(static_cast<int64_t>(obj.size()));
}},
{"tojson", [](const func_args & args) -> value {
@@ -958,21 +980,18 @@ const func_builtins & value_object_t::get_builtins() const {
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
value val_by = args.get_kwarg_or_pos("by", 2);
value val_reverse = args.get_kwarg_or_pos("reverse", 3);
// FIXME: sorting is case sensitive
// FIXME: sorting is currently always case sensitive
//const bool case_sensitive = val_case->as_bool(); // undefined == false
const bool reverse = val_reverse->as_bool(); // undefined == false
if (!val_by->is_undefined()) {
throw not_implemented_exception("dictsort by key not implemented");
}
if (reverse) {
throw not_implemented_exception("dictsort reverse not implemented");
}
value_t::map obj = val_input->val_obj; // copy
std::sort(obj.ordered.begin(), obj.ordered.end(), [&](const auto & a, const auto & b) {
return a.first < b.first;
const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
auto result = mk_val<value_object>(val_input); // copy
std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
if (by_value) {
return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
} else {
return reverse ? a.first > b.first : a.first < b.first;
}
});
auto result = mk_val<value_object>();
result->val_obj = std::move(obj);
return result;
}},
{"join", [](const func_args &) -> value {
@@ -1169,7 +1188,7 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
}
oss << "]";
} else if (is_val<value_object>(val)) {
const auto & obj = val->val_obj.ordered; // IMPORTANT: need to keep exact order
const auto & obj = val->as_ordered_object(); // IMPORTANT: need to keep exact order
oss << "{";
if (!obj.empty()) {
oss << newline();
+31 -4
View File
@@ -146,7 +146,7 @@ struct value_t {
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
virtual bool is_none() const { return false; }
virtual bool is_undefined() const { return false; }
@@ -154,6 +154,9 @@ struct value_t {
throw std::runtime_error("No builtins available for type " + type());
}
virtual bool has_key(const std::string & key) {
return val_obj.unordered.find(key) != val_obj.unordered.end();
}
virtual value & at(const std::string & key, value & default_val) {
auto it = val_obj.unordered.find(key);
if (it == val_obj.unordered.end()) {
@@ -168,8 +171,20 @@ struct value_t {
}
return val_obj.unordered.at(key);
}
virtual value & at(size_t index) {
if (index >= val_arr.size()) {
virtual value & at(int64_t index, value & default_val) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
return default_val;
}
return val_arr[index];
}
virtual value & at(int64_t index) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
}
return val_arr[index];
@@ -188,6 +203,9 @@ struct value_int_t : public value_t {
virtual int64_t as_int() const override { return val_int; }
virtual double as_float() const override { return static_cast<double>(val_int); }
virtual string as_string() const override { return std::to_string(val_int); }
virtual bool as_bool() const override {
return val_int != 0;
}
virtual const func_builtins & get_builtins() const override;
};
using value_int = std::shared_ptr<value_int_t>;
@@ -204,6 +222,9 @@ struct value_float_t : public value_t {
if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals
return out;
}
virtual bool as_bool() const override {
return val_flt != 0.0;
}
virtual const func_builtins & get_builtins() const override;
};
using value_float = std::shared_ptr<value_float_t>;
@@ -286,6 +307,7 @@ using value_array = std::shared_ptr<value_array_t>;
struct value_object_t : public value_t {
bool has_builtins = true; // context and loop objects do not have builtins
value_object_t() = default;
value_object_t(value & v) {
val_obj = v->val_obj;
@@ -295,11 +317,16 @@ struct value_object_t : public value_t {
val_obj.insert(pair.first, pair.second);
}
}
value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
for (const auto & pair : obj) {
val_obj.insert(pair.first, pair.second);
}
}
void insert(const std::string & key, const value & val) {
val_obj.insert(key, val);
}
virtual std::string type() const override { return "Object"; }
virtual const std::map<std::string, value> & as_object() const override { return val_obj.unordered; }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
virtual bool as_bool() const override {
return !val_obj.unordered.empty();
}
+31 -2
View File
@@ -1078,6 +1078,9 @@ class TextModel(ModelBase):
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -7458,7 +7461,7 @@ class DeepseekModel(TextModel):
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration"
"YoutuVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -8446,6 +8449,32 @@ class Glm4MoeModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("Glm4MoeLiteForCausalLM")
class Glm4MoeLiteModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# copied from Glm4MoeModel
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(TextModel):
model_arch = gguf.MODEL_ARCH.CHATGLM
@@ -9183,7 +9212,7 @@ class NemotronHModel(GraniteHybridModel):
return [(mapped_name, reshaped_data)]
if name.endswith("mixer.norm.weight"):
reshaped_data = data_torch.reshape(8, 512)
reshaped_data = data_torch.reshape(self.n_group, -1)
mapped_name = self.map_tensor_name(name)
return [(mapped_name, reshaped_data)]
+1
View File
@@ -170,6 +170,7 @@ pre_computed_hashes = [
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
]
+1
View File
@@ -8,6 +8,7 @@
- [CMake Options](#cmake-options)
- [Android](#android)
- [Windows 11 Arm64](#windows-11-arm64)
- [Linux](#Linux)
- [Known Issue](#known-issues)
- [TODO](#todo)
+24 -24
View File
@@ -20,10 +20,10 @@ Legend:
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@@ -34,20 +34,20 @@ Legend:
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -61,9 +61,9 @@ Legend:
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@@ -72,9 +72,10 @@ Legend:
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -82,39 +83,38 @@ Legend:
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE | ❌ | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| SSM_CONV | ❌ | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_CONV | ❌ | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
+15352 -4630
View File
File diff suppressed because it is too large Load Diff
+7683 -7584
View File
File diff suppressed because it is too large Load Diff
@@ -4,6 +4,7 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
BUILD_DIR="${2:-"$BUILD_DIR"}"
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@@ -13,6 +14,10 @@ if [ -z "$CONVERTED_MODEL" ]; then
exit 1
fi
cmake --build ../../build --target llama-debug -j8
if [ -z "$BUILD_DIR" ]; then
BUILD_DIR="../../build"
fi
../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
cmake --build ${BUILD_DIR} --target llama-debug -j8
${BUILD_DIR}/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
@@ -5,11 +5,16 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
BUILD_DIR="${3:-"$BUILD_DIR"}"
if [ -z "$MODEL_TESTING_PROMPT"]; then
if [ -z "$MODEL_TESTING_PROMPT" ]; then
MODEL_TESTING_PROMPT="Hello, my name is"
fi
if [ -z "$BUILD_DIR" ]; then
BUILD_DIR="../../build"
fi
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
echo "Error: Model path must be provided either as:" >&2
@@ -21,6 +26,6 @@ fi
echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT
cmake --build ../../build --target llama-debug -j8
cmake --build ${BUILD_DIR} --target llama-debug -j8
../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
@@ -28,6 +28,7 @@ done
# First try command line argument, then environment variable
CONVERTED_MODEL="${CONVERTED_MODEL:-"$CONVERTED_EMBEDDING_MODEL"}"
BUILD_DIR="${BUILD_DIR:-"../../build"}"
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@@ -50,5 +51,5 @@ fi
echo $CONVERTED_MODEL
cmake --build ../../build --target llama-debug -j8
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
cmake --build ${BUILD_DIR} --target llama-debug -j8
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
+39 -7
View File
@@ -630,10 +630,11 @@ extern "C" {
// this tensor...
enum ggml_tensor_flag {
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
};
enum ggml_tri_type {
@@ -2577,11 +2578,42 @@ extern "C" {
struct ggml_tensor * grad,
struct ggml_tensor * sgd_params); // alpha, weight decay
// build forward mutiple tensors and select one of them for computing
// this is useful for creating graphs that have constant topology but compute different things based on the input
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
//
// automatic differentiation
// nodes:
// | - build forward into the graph but do not compute
// c - build forward into the graph and compute
//
// | | ... c ... |
// | | ... c ... |
// | | ... c ... |
// [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx)
// c
// c
//
// example:
// struct ggml_tensor * curs[3];
//
// curs[0] = compute0(...);
// curs[1] = compute1(...);
// curs[2] = compute2(...);
//
// int idx = select_branch(some_input);
//
// struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
//
GGML_API struct ggml_tensor * ggml_build_forward_select(
struct ggml_cgraph * cgraph,
struct ggml_tensor ** tensors,
int n_tensors,
int idx);
GGML_API void ggml_build_forward_expand(
struct ggml_cgraph * cgraph,
struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(
struct ggml_context * ctx, // context for gradient computation
struct ggml_cgraph * cgraph,
@@ -2613,7 +2645,7 @@ extern "C" {
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
// dump the graph into a file using the dot format
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
// TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
+4 -20
View File
@@ -77,39 +77,23 @@
#include "ggml-zendnn.h"
#endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
namespace fs = std::filesystem;
static std::string path_str(const fs::path & path) {
std::string u8path;
try {
#if defined(__cpp_lib_char8_t)
// C++20 and later: u8string() returns std::u8string
std::u8string u8str = path.u8string();
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
const std::u8string u8str = path.u8string();
return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
#else
// C++17: u8string() returns std::string
u8path = path.u8string();
return path.u8string();
#endif
} catch (...) {
return std::string();
}
return u8path;
}
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
+3 -2
View File
@@ -874,9 +874,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
}
if (sched->debug > 1) {
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
@@ -1922,6 +1922,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
dst->view_offs = src->view_offs;
}
dst->op = src->op;
dst->flags = src->flags;
memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
ggml_set_name(dst, src->name);
+1 -1
View File
@@ -93,7 +93,7 @@ if (BLAS_FOUND)
endif()
target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
target_include_directories(ggml-blas SYSTEM PRIVATE ${BLAS_INCLUDE_DIRS})
else()
message(FATAL_ERROR "BLAS not found, please refer to "
"https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
+4
View File
@@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_backend_blas_mul_mat(ctx, node);
+4
View File
@@ -2146,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+4
View File
@@ -2943,6 +2943,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
ggml_compute_forward(&params, node);
if (state->ith == 0 && cplan->abort_callback &&
+64 -38
View File
@@ -7,10 +7,9 @@
#include "unary-ops.h"
#include "vec.h"
#include <cfloat>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
// ggml_compute_forward_dup
@@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
}
}
// ggml_compute_forward_pool_1d_sk_p0
static void ggml_compute_forward_pool_1d_sk_p0(
// ggml_compute_forward_pool_1d_ksp
static void ggml_compute_forward_pool_1d_ksp(
const ggml_compute_params * params,
const ggml_op_pool op,
const int k,
const int s,
const int p,
ggml_tensor * dst) {
const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
return;
}
const char * cdata = (const char *)src->data;
const char * const data_end = cdata + ggml_nbytes(src);
float * drow = (float *)dst->data;
const int64_t IW = src->ne[0];
const int64_t OW = dst->ne[0];
const int64_t rs = dst->ne[0];
const int64_t nr = ggml_nrows(src);
while (cdata < data_end) {
const void * srow = (const void *)cdata;
int j = 0;
for (int64_t i = 0; i < rs; ++i) {
for (int64_t ir = 0; ir < nr; ++ir) {
const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
for (int64_t ow = 0; ow < OW; ++ow) {
float res = 0;
switch (op) {
case GGML_OP_POOL_AVG: drow[i] = 0; break;
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
case GGML_OP_POOL_AVG: res = 0.0f; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
int count = 0;
const int base = (int) ow * s - p;
for (int ki = 0; ki < k; ++ki) {
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
switch (op) {
case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
const int j = base + ki;
if (j < 0 || j >= (int) IW) {
continue;
}
++j;
float v;
if (src->type == GGML_TYPE_F32) {
v = ((const float *) srow_bytes)[j];
} else {
v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
}
switch (op) {
case GGML_OP_POOL_AVG: res += v; break;
case GGML_OP_POOL_MAX: res = std::max(v, res); break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
++count;
}
switch (op) {
case GGML_OP_POOL_AVG: drow[i] /= k; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
}
cdata += src->nb[1];
drow += rs;
drow[ow] = res;
}
}
}
@@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
const int k0 = opts[1];
const int s0 = opts[2];
const int p0 = opts[3];
GGML_ASSERT(p0 == 0); // padding not supported
GGML_ASSERT(k0 == s0); // only s = k supported
ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
}
// ggml_compute_forward_pool_2d
@@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
}
const int32_t * opts = (const int32_t *)dst->op_params;
ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
@@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d(
while (cdata < data_end) {
for (int oy = 0; oy < py; ++oy) {
float * const drow = dplane + oy * px;
float * const out = drow;
for (int ox = 0; ox < px; ++ox) {
float * const out = drow + ox;
float res = 0;
switch (op) {
case GGML_OP_POOL_AVG: *out = 0; break;
case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
@@ -7229,24 +7247,32 @@ void ggml_compute_forward_pool_2d(
const int iy = offset1 + oy * s1;
for (int ky = 0; ky < k1; ++ky) {
if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
if (iy + ky < 0 || iy + ky >= src->ne[1]) {
continue;
}
const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
for (int kx = 0; kx < k0; ++kx) {
int j = ix + kx;
if (j < 0 || j >= src->ne[0]) continue;
if (j < 0 || j >= src->ne[0]) {
continue;
}
const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
switch (op) {
case GGML_OP_POOL_AVG: *out += srow_j; break;
case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
case GGML_OP_POOL_AVG: res += srow_j; break;
case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
}
}
switch (op) {
case GGML_OP_POOL_AVG: *out /= ka; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_AVG: res /= ka; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
}
out[ox] = res;
}
}
+1
View File
@@ -1123,6 +1123,7 @@ struct ggml_tensor_extra_gpu {
struct ggml_cuda_graph_node_properties {
void * node_address;
ggml_op node_op;
int32_t flags;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
+8
View File
@@ -2918,6 +2918,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
props->node_address = node->data;
props->node_op = node->op;
props->flags = node->flags;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
props->ne[i] = node->ne[i];
props->nb[i] = node->nb[i];
@@ -2961,6 +2962,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
return false;
}
return true;
}
@@ -3378,6 +3383,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
+4
View File
@@ -2497,6 +2497,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
uint32_t flags = 0;
// skip quantizer if src1 is reused
+3
View File
@@ -611,6 +611,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in
if (node->op != ops[i]) {
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return false;
}
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
return false;
}
+25
View File
@@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
const char * pool_str = "undefined";
switch (op_pool) {
case GGML_OP_POOL_AVG: pool_str = "avg"; break;
case GGML_OP_POOL_MAX: pool_str = "max"; break;
default: GGML_ASSERT(false && "not implemented");
};
char base[256];
char name[256];
snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
snprintf(name, sizeof(name), "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
+1
View File
@@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
+4 -8
View File
@@ -1044,10 +1044,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
case GGML_OP_POOL_1D:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_POOL_2D:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_PAD:
@@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek sizes
// TODO: disabled for now, until optmized
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {
+9
View File
@@ -928,6 +928,15 @@ typedef struct {
int64_t np;
} ggml_metal_kargs_pool_2d;
typedef struct {
int32_t k0;
int32_t s0;
int32_t p0;
int64_t IW;
int64_t OW;
int64_t np;
} ggml_metal_kargs_pool_1d;
typedef struct {
int64_t ne00;
uint64_t nb01;
+57 -1
View File
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported op");
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return 1;
}
int n_fuse = 1;
// check if the current node can run concurrently with other nodes before it
@@ -432,6 +436,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_cpy(ctx, idx);
} break;
case GGML_OP_POOL_1D:
{
n_fuse = ggml_metal_op_pool_1d(ctx, idx);
} break;
case GGML_OP_POOL_2D:
{
n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -1622,6 +1630,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t * opts = op->op_params;
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
const int32_t k0 = opts[1];
const int32_t s0 = opts[2];
const int32_t p0 = opts[3];
const int64_t IW = op->src[0]->ne[0];
const int64_t OW = op->ne[0];
const int64_t np = ggml_nelements(op);
ggml_metal_kargs_pool_1d args_pool_1d = {
/* .k0 = */ k0,
/* .s0 = */ s0,
/* .p0 = */ p0,
/* .IW = */ IW,
/* .OW = */ OW,
/* .np = */ np
};
auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
const int ntg = (np + nth - 1) / nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -2464,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
int32_t nsg = 4;
int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);
+1
View File
@@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
+76 -5
View File
@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
constexpr short NC = (C/8)/NSG;
// note: do not unroll for large heads
#pragma unroll (DK <= 64 ? NC : 1)
for (short cc = 0; cc < NC; ++cc) {
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
if (DK % 16 != 0) {
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
k8x8_t mk[2];
q8x8_t mq[2];
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
// note: too much unroll can tank the performance for large heads
#pragma unroll (MIN(DK8/2, 4*NSG))
for (short i = 0; i < DK8/2; ++i) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
pv += 8*NS20;
}
} else {
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
constexpr short NC = (C/8)/2;
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
s8x8_t vs[2];
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS
@@ -9869,6 +9872,74 @@ kernel void kernel_pool_2d_avg_f32(
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_1d_max_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = -INFINITY;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
int j = base + ki;
if (j < 0 || j >= args.IW){
continue;
}
float v = src[src_off + j];
acc = max(acc, v);
}
dst[dst_off + ow] = acc;
}
kernel void kernel_pool_1d_avg_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = 0.0f;
int cnt = 0;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
const int j = base + ki;
if (j < 0 || j >= args.IW) {
continue;
}
acc += src[src_off + j];
cnt += 1;
}
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
+4
View File
@@ -3058,6 +3058,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
@@ -111,6 +111,10 @@ kernel void kernel_mul_mv_q6_K_f32(
int row = N_SIMDGROUP * r0 + get_sub_group_id();
if (row >= ne01) {
return;
}
int i12 = im%ne12;
int i13 = im/ne12;
+3
View File
@@ -4109,6 +4109,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
+4 -1
View File
@@ -12191,6 +12191,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return false;
}
VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
ctx->semaphore_idx = 0;
@@ -13645,7 +13648,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
int last_node = cgraph->n_nodes - 1;
// If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
last_node -= 1;
}
+335 -36
View File
@@ -9,12 +9,28 @@
#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_I32_SIZE_BYTES 4
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u
struct ggml_webgpu_flash_attn_shader_lib_context {
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
void * decisions;
};
// Same hash combine function as in boost
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
/** FlashAttention */
struct ggml_webgpu_flash_attn_pipeline_key {
ggml_type kv_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
@@ -22,11 +38,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context {
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
size_t wg_mem_limit_bytes;
uint32_t max_subgroup_size;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap;
}
};
struct ggml_webgpu_flash_attn_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.kv_type);
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.kv_direct);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
return seed;
}
};
struct ggml_webgpu_flash_attn_shader_lib_context {
ggml_webgpu_flash_attn_pipeline_key key;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
size_t wg_mem_limit_bytes;
uint32_t max_subgroup_size;
};
struct ggml_webgpu_flash_attn_shader_decisions {
@@ -35,12 +75,6 @@ struct ggml_webgpu_flash_attn_shader_decisions {
uint32_t wg_size = 0;
};
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
ggml_webgpu_flash_attn_shader_decisions decisions;
};
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
@@ -66,15 +100,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
}
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
const size_t limit_bytes = context.wg_mem_limit_bytes;
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes =
(context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!context.kv_direct) {
bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
if (!context.key.kv_direct) {
bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
}
if (context.has_mask) {
if (context.key.has_mask) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
@@ -90,7 +125,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
std::vector<std::string> defines;
std::string variant = "flash_attn";
switch (context.kv_type) {
switch (context.key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
@@ -106,32 +141,31 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(context.kv_type);
variant += std::string("_") + ggml_type_name(context.key.kv_type);
if (context.has_mask) {
if (context.key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
if (context.has_sinks) {
if (context.key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (context.uses_logit_softcap) {
if (context.key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (context.kv_direct) {
if (context.key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.head_dim_v);
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
// For now these are not part of the variant name
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
@@ -141,7 +175,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (context.kv_direct) {
if (context.key.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
// Avoids having to use bounds-checks and decreasing performance for direct KV loads
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
@@ -158,11 +192,276 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
result.decisions.q_tile = q_tile;
result.decisions.kv_tile = kv_tile;
result.decisions.wg_size = wg_size;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
/** Generic **/
struct ggml_webgpu_generic_shader_lib_context {
int vec4;
uint32_t max_wg_size;
};
struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_generic_shader_lib_context & context,
const std::string & base_variant) {
std::vector<std::string> defines;
std::string variant = base_variant;
if (context.vec4) {
defines.push_back("VEC4");
variant += "_vec";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};
struct ggml_webgpu_pad_pipeline_key_hash {
size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.circular);
return seed;
}
};
struct ggml_webgpu_pad_shader_lib_context {
ggml_webgpu_pad_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_pad_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "pad";
if (context.key.circular) {
defines.push_back("CIRCULAR");
variant += "_circular";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
/** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context {
uint32_t max_wg_size;
size_t wg_mem_limit_bytes;
int32_t order;
};
struct ggml_webgpu_argsort_shader_decisions {
uint32_t wg_size = 0;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_argsort_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "argsort";
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
variant += std::string("_order") + std::to_string(context.order);
uint32_t wg_size = 1;
while (wg_size * 2 <= context.max_wg_size &&
wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
wg_size *= 2;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_argsort_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "argsort_merge";
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
variant += std::string("_order") + std::to_string(context.order);
uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
/** Set Rows **/
struct ggml_webgpu_set_rows_pipeline_key {
int dst_type;
int vec4;
int i64_idx;
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
}
};
struct ggml_webgpu_set_rows_pipeline_key_hash {
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.vec4);
ggml_webgpu_hash_combine(seed, key.i64_idx);
return seed;
}
};
struct ggml_webgpu_set_rows_shader_lib_context {
ggml_webgpu_set_rows_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_set_rows_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "set_rows";
switch (context.key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
variant += "_dstf32";
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
variant += "_dstf16";
break;
default:
GGML_ABORT("Unsupported dst type for set_rows shader");
}
if (context.key.vec4) {
defines.push_back("VEC4");
variant += "_vec";
}
if (context.key.i64_idx) {
defines.push_back("I64_IDX");
variant += "_i64idx";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
struct ggml_webgpu_unary_pipeline_key {
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
}
};
struct ggml_webgpu_unary_pipeline_key_hash {
size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
struct ggml_webgpu_unary_shader_lib_context {
ggml_webgpu_unary_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_unary_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
ggml_op_name((ggml_op) context.key.op);
// Operation-specific behavior
defines.push_back(variant);
switch (context.key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for unary shader");
}
if (context.key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,72 @@
@group(0) @binding(0)
#ifdef VEC4
var<storage, read_write> src: array<vec4<f32>>;
#define VEC_SIZE 4
#else
var<storage, read_write> src: array<f32>;
#define VEC_SIZE 1
#endif
@group(0) @binding(1)
var<storage, read_write> dst: array<i32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
ne0: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
const FLOAT_MIN: f32 = -1.0e9;
struct Pair {
value: f32,
index: i32
};
var<workgroup> shared_max: array<Pair, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let row_idx = params.offset_src + wid.x * params.ne0;
var local_pair = Pair(FLOAT_MIN, -1);
#ifdef VEC4
for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {
let vec_val = src[row_idx / VEC_SIZE + col];
for (var v = 0u; v < VEC_SIZE; v++) {
let val = vec_val[v];
if (val >= local_pair.value) {
local_pair = Pair(val, i32(col * VEC_SIZE + v));
}
}
}
#else
for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
if (src[row_idx + col] >= local_pair.value) {
local_pair = Pair(src[row_idx + col], i32(col));
}
}
#endif
shared_max[lid.x] = local_pair;
workgroupBarrier();
var offset: u32 = WG_SIZE >> 1;
while (offset > 0) {
if (lid.x < offset) {
let a = shared_max[lid.x];
let b = shared_max[lid.x + offset];
if (b.value > a.value) {
shared_max[lid.x] = b;
} else if (b.value == a.value && b.index > a.index) {
shared_max[lid.x] = b;
}
}
workgroupBarrier();
offset >>= 1;
}
if (lid.x == 0u) {
dst[params.offset_dst + wid.x] = shared_max[0].index;
}
}
@@ -0,0 +1,106 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<i32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// src/dst dimensions
src_ne0: u32,
ne1: u32,
ne2: u32,
ne0: u32,
top_k: u32,
npr: u32, // tiles per row
nrows: u32
};
@group(0) @binding(2)
var<uniform> params: Params;
var<workgroup> shmem_idx: array<u32, WG_SIZE>;
#if ORDER == 0
#define EXTREME_VALUE 1e30
#define SWAP_COMPARE_UP >
#define SWAP_COMPARE_DOWN <
#else
#define EXTREME_VALUE -1e30
#define SWAP_COMPARE_UP <
#define SWAP_COMPARE_DOWN >
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let linear = wid.x + wid.y * num_wg.x;
// guard against overprovisioned workgroups
if (linear >= params.npr * params.nrows) {
return;
}
let tile = linear % params.npr;
var row = linear / params.npr;
let i3 = row / (params.ne2 * params.ne1);
row = row % (params.ne2 * params.ne1);
let i2 = row / params.ne1;
let i1 = row % params.ne1;
let row_base = params.offset_src +
i1 * params.stride_src1 +
i2 * params.stride_src2 +
i3 * params.stride_src3;
let tile_base = tile * WG_SIZE;
let idx = tile_base + lid.x;
shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
workgroupBarrier();
var k = 2u;
while (k <= WG_SIZE) {
var j = k >> 1;
while (j > 0) {
let ixj = lid.x ^ j;
if (ixj > lid.x) {
let dir_up = (lid.x & k) == 0;
let a_idx = shmem_idx[lid.x];
let b_idx = shmem_idx[ixj];
let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0);
let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0);
let should_swap = select(
(a_val SWAP_COMPARE_DOWN b_val),
(a_val SWAP_COMPARE_UP b_val),
dir_up);
if (should_swap) {
shmem_idx[lid.x] = b_idx;
shmem_idx[ixj] = a_idx;
}
}
workgroupBarrier();
j >>= 1;
}
k <<= 1;
}
let out_idx = tile * params.top_k + lid.x;
if (out_idx < params.ne0 && lid.x < params.top_k) {
let row_dst = params.offset_dst +
i1 * params.stride_dst1 +
i2 * params.stride_dst2 +
i3 * params.stride_dst3;
dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
}
}
@@ -0,0 +1,134 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> idx_in: array<i32>;
@group(0) @binding(2)
var<storage, read_write> idx_out: array<i32>;
struct Params {
offset_src: u32, // in elements
offset_in: u32, // in elements
offset_out: u32, // in elements
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_idx1: u32,
stride_idx2: u32,
stride_idx3: u32,
stride_out1: u32,
stride_out2: u32,
stride_out3: u32,
ne0: u32,
ne1: u32,
ne2: u32,
top_k: u32,
len: u32,
nm: u32,
nrows: u32
};
@group(0) @binding(3)
var<uniform> params: Params;
fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
let a_val = src[row_base + u32(a_idx)];
let b_val = src[row_base + u32(b_idx)];
#if ORDER == 0
return a_val <= b_val;
#else
return a_val >= b_val;
#endif
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let linear = wid.x + wid.y * num_wg.x;
// guard against overprovisioned workgroups
if (linear >= params.nm * params.nrows) {
return;
}
let start = (linear % params.nm) * params.len * 2;
let len0 = min(params.len, params.ne0 - start);
let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
let len1 = min(params.len, rem1);
let total = len0 + len1;
let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
let k0 = lid.x * chunk;
let k1 = min(min(k0 + chunk, total), params.top_k);
// guard against overprovisioned threads
if (k0 >= params.top_k || k0 >= total) {
return;
}
var row = linear / params.nm;
let i3 = row / (params.ne2 * params.ne1);
row = row % (params.ne2 * params.ne1);
let i2 = row / params.ne1;
let i1 = row % params.ne1;
let row_src = params.offset_src +
i1 * params.stride_src1 +
i2 * params.stride_src2 +
i3 * params.stride_src3;
let row_in = params.offset_in +
i1 * params.stride_idx1 +
i2 * params.stride_idx2 +
i3 * params.stride_idx3;
let row_out = params.offset_out +
i1 * params.stride_out1 +
i2 * params.stride_out2 +
i3 * params.stride_out3;
var low: u32 = select(0, k0 - len1, k0 > len1);
var high: u32 = min(k0, len0);
while (low < high) {
let mid = (low + high) >> 1;
let idx0 = idx_in[row_in + start + mid];
let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
if (take_left(idx0, idx1, row_src)) {
low = mid + 1;
} else {
high = mid;
}
}
var i = low;
var j = k0 - i;
var k = k0;
while (k < k1) {
var take_l = false;
if (i >= len0) {
take_l = false;
} else if (j >= len1) {
take_l = true;
} else {
let idx0 = idx_in[row_in + start + i];
let idx1 = idx_in[row_in + start + params.len + j];
take_l = take_left(idx0, idx1, row_src);
}
let out_idx = select(
idx_in[row_in + start + params.len + j],
idx_in[row_in + start + i],
take_l);
idx_out[row_out + start + k] = out_idx;
i = select(i, i + 1, take_l);
j = select(j + 1, j, take_l);
k += 1;
}
}
@@ -7,6 +7,12 @@
"DST_TYPE": "f32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "i32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
@@ -0,0 +1,66 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
ne0: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
var<workgroup> shared_sum: array<f32, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let row_idx = params.offset_src + wid.x * params.ne0;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
var local_sum: f32 = 0.0;
for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
local_sum += src[row_idx + col];
}
shared_sum[lid.x] = local_sum;
workgroupBarrier();
// upsweep
var offset = 1u;
while (offset < WG_SIZE) {
let idx = (lid.x + 1) * offset * 2 - 1;
if (idx < WG_SIZE) {
shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset];
}
workgroupBarrier();
offset <<= 1;
}
// set last to 0 for exclusive sum
if (lid.x == 0) {
shared_sum[WG_SIZE - 1] = 0.0;
}
workgroupBarrier();
// downsweep
offset = WG_SIZE >> 1;
while (offset > 0) {
let idx = (lid.x + 1) * offset * 2 - 1;
if (idx < WG_SIZE) {
let t = shared_sum[idx - offset];
shared_sum[idx - offset] = shared_sum[idx];
shared_sum[idx] = shared_sum[idx] + t;
}
workgroupBarrier();
offset = offset >> 1;
}
// shared_sum[lid] is exclusive prefix sum up to this thread.
var running_sum = shared_sum[lid.x];
for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
running_sum += src[row_idx + col];
dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum;
}
}
@@ -0,0 +1,86 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
src_ne3: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
dst_ne3: u32,
// Pad sizes (in elements)
lp0: u32,
rp0: u32,
lp1: u32,
rp1: u32,
lp2: u32,
rp2: u32,
lp3: u32,
rp3: u32,
};
@group(0) @binding(2)
var<uniform> params: Params;
fn wrap_around(idx: i32, n: u32) -> u32 {
return u32(idx + i32(n)) % n;
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0;
let i3 = i / dst_plane;
i = i % dst_plane;
let i2 = i / (params.dst_ne1 * params.dst_ne0);
i = i % (params.dst_ne1 * params.dst_ne0);
let i1 = i / params.dst_ne0;
let i0 = i % params.dst_ne0;
var value: f32 = 0.0;
#ifdef CIRCULAR
let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);
let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);
let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);
let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);
let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 +
ci2 * params.stride_src2 + ci3 * params.stride_src3;
value = src[params.offset_src + circular_src_idx];
#else
let is_src =
(i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) &&
(i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) &&
(i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) &&
(i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3);
if (is_src) {
let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 +
(i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3;
value = src[params.offset_src + src_idx];
}
#endif
dst[params.offset_dst + gid.x] = value;
}
@@ -1,41 +1,37 @@
#define(VARIANTS)
[
{
"SHADER_SUFFIX": "f16_vec",
"REPLS": {
"TYPE" : "vec4<f32>",
"DST_TYPE": "vec4<f16>",
"VEC_SIZE": 4
}
},
{
"SHADER_SUFFIX": "f16",
"REPLS": {
"TYPE" : "f32",
"DST_TYPE": "f16",
"VEC_SIZE": 1
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#ifdef DST_F32
#define DST_INNER_TYPE f32
#else
#define DST_INNER_TYPE f16
#endif
#ifdef VEC4
#define SRC_TYPE vec4<f32>
#define DST_TYPE vec4<DST_INNER_TYPE>
#define VEC_SIZE 4
#else
#define SRC_TYPE f32
#define DST_TYPE DST_INNER_TYPE
#define VEC_SIZE 1
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
#ifdef I64_IDX
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
#define PARAMS_BINDING 4
#else
#define PARAMS_BINDING 3
#endif
struct Params {
offset_src: u32, // in elements
@@ -66,18 +62,17 @@ struct Params {
idx2: u32,
};
@group(0) @binding(4)
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
return;
}
// getting the row from gid
let elems_per_row = params.ne0 / {{VEC_SIZE}};
let elems_per_row = params.ne0 / VEC_SIZE;
var i = gid.x / elems_per_row;
let i_src3 = i / (params.ne2 * params.n_rows);
@@ -90,9 +85,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_idx1 = i_src2 % params.idx1;
let i_idx0 = i_src1;
#ifdef I64_IDX
let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
let idx_high_val = idx[idx_high];
let idx_val = idx[idx_high];
let idx_low_val = idx[idx_high + 1];
if (idx_low_val != 0) {
@@ -100,13 +96,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
atomicStore(&error, 1);
return;
}
#else
let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
let idx_val = idx[idx_i];
#endif
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
let col_idx = (gid.x % elems_per_row);
dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
}
#end(SHADER)
@@ -0,0 +1,55 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
struct Params {
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
ne0: u32,
ne1: u32,
ne2: u32
};
@group(0) @binding(2)
var<uniform> params: Params;
var<workgroup> shared_sum: array<f32, WG_SIZE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
var local_sum: f32 = 0.0;
for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
local_sum += src[i_src_row + col];
}
shared_sum[lid.x] = local_sum;
workgroupBarrier();
// reduce within workgroup
var offset: u32 = WG_SIZE >> 1;
while (offset > 0) {
if (lid.x < offset) {
shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset];
}
workgroupBarrier();
offset >>= 1;
}
if (lid.x == 0) {
dst[params.offset_dst + wid.x] = shared_sum[0];
}
}
@@ -0,0 +1,179 @@
#ifdef TYPE_F16
enable f16;
#define TYPE f16
#else
#define TYPE f32
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<TYPE>;
#ifndef INPLACE
@group(0) @binding(1)
var<storage, read_write> dst: array<TYPE>;
#define PARAMS_BINDING 2
#else
#define PARAMS_BINDING 1
#endif
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements)
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
// Logical shapes
ne0: u32,
ne1: u32,
ne2: u32,
#ifdef CLAMP
clamp_min: f32,
clamp_max: f32,
#endif
#ifdef FILL
fill_val: f32,
#endif
#ifdef XIELU
alpha_n: f32,
alpha_p: f32,
beta: f32,
eps: f32,
#endif
};
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
#ifdef ABS
let res = abs(src[params.offset_src + src_idx]);
#endif
#ifdef SGN
let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef NEG
let res = -src[params.offset_src + src_idx];
#endif
#ifdef STEP
let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
#endif
#ifdef TANH
let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
#endif
#ifdef RELU
let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef ELU
let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef HARDSIGMOID
let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
#endif
#ifdef SIGMOID
let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
#endif
#ifdef SILU
let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
#endif
#ifdef EXP
let res = exp(src[params.offset_src + src_idx]);
#endif
#ifdef LOG
let res = TYPE(log(f32(src[params.offset_src + src_idx])));
#endif
#ifdef CLAMP
let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
#endif
#ifdef FILL
let res = TYPE(params.fill_val);
#endif
#ifdef HARDSWISH
let res = src[params.offset_src + src_idx] *
min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
#endif
#ifdef GELU
let res = 0.5 * src[params.offset_src + src_idx] *
(1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
(src[params.offset_src + src_idx] +
0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
-9.010913, 9.010913)));
#endif
#ifdef GELU_QUICK
let res = src[params.offset_src + src_idx] * 0.5 *
(1.0 + tanh(clamp(0.79788456 *
(src[params.offset_src + src_idx] +
0.044715 * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
-9.010913, 9.010913)));
#endif
#ifdef GELU_ERF
let res = 0.5 * src[params.offset_src + src_idx] *
(1.0 + tanh(clamp(0.79788456 *
(src[params.offset_src + src_idx] +
0.044715 * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
-9.010913, 9.010913)));
#endif
#ifdef XIELU
let res =
select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
src[params.offset_src + src_idx]) *
TYPE(params.alpha_n) +
TYPE(params.beta) * src[params.offset_src + src_idx],
TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] +
TYPE(params.beta) * src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef SOFTPLUS
let src_f32 = f32(src[params.offset_src + src_idx]);
let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
#endif
#ifdef EXPM1
let res = exp(src[params.offset_src + src_idx]) - 1.0;
#endif
#ifdef FLOOR
let res = floor(src[params.offset_src + src_idx]);
#endif
#ifdef CEIL
let res = ceil(src[params.offset_src + src_idx]);
#endif
#ifdef ROUND
let src_f32 = f32(src[params.offset_src + src_idx]);
let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
let res = TYPE(result);
#endif
#ifdef TRUNC
let res = trunc(src[params.offset_src + src_idx]);
#endif
#ifdef INPLACE
src[params.offset_src + src_idx] = res;
#else
dst[params.offset_dst + gid.x] = res;
#endif
}
@@ -1,483 +0,0 @@
#define(REPL_TEMPLATES)
{
"XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
"ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
"SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
"NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
"STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
"TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
"ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
"HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
"SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
"SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
"EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
"HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
"GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
"CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);"
}
#end(REPL_TEMPLATES)
#define(VARIANTS)
[
{
"SHADER_NAME": "abs_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "abs_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "abs_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "abs_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sgn_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sgn_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sgn_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sgn_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "neg_f32",
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "neg_f16",
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "neg_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "neg_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "step_f32",
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "step_f16",
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "step_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "step_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "tanh_f32",
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "tanh_f16",
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "tanh_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "tanh_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "elu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "elu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "elu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "elu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "relu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "relu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "relu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "relu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sigmoid_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sigmoid_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "silu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "silu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "silu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "silu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "exp_f32",
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "exp_f16",
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "exp_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "exp_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardswish_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardswish_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardswish_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardswish_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "xielu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "xielu_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "ceil_f32",
"REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "ceil_f16",
"REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "ceil_inplace_f32",
"REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "ceil_inplace_f16",
"REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(INPLACE)
#decl(NOT_INPLACE)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
fn update(dst_i: u32, src_i: u32) {
{{FUNC}}
}
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
DECLS
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
{{EXT_PARAMS}}
};
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
let i2 = i / (params.src_ne1 * params.src_ne0);
i = i % (params.src_ne1 * params.src_ne0);
let i1 = i / params.src_ne0;
let i0 = i % params.src_ne0;
var j = gid.x;
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
let j2 = j / (params.dst_ne1 * params.dst_ne0);
j = j % (params.dst_ne1 * params.dst_ne0);
let j1 = j / params.dst_ne0;
let j0 = j % params.dst_ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;
update(params.offset_dst + dst_idx, params.offset_src + src_idx);
}
#end(SHADER)
+4
View File
@@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr
continue;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
bool ok = ggml_zdnn_compute_forward(ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: unsupported op %s (%s)\n",
+4
View File
@@ -211,6 +211,10 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_zendnn_compute_forward_mul_mat(ctx, node);
+57 -18
View File
@@ -3441,7 +3441,8 @@ struct ggml_tensor * ggml_cast(
result->op = GGML_OP_CPY;
result->src[0] = a;
result->src[1] = result;
result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some
// backends for consistency with ggml_cpy_impl() above
return result;
}
@@ -4838,6 +4839,8 @@ struct ggml_tensor * ggml_pool_1d(
a->ne[2],
a->ne[3],
};
GGML_ASSERT(ne[0] > 0);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
int32_t params[] = { op, k0, s0, p0 };
@@ -4868,6 +4871,9 @@ struct ggml_tensor * ggml_pool_2d(
a->ne[2],
a->ne[3],
};
GGML_ASSERT(ne[0] > 0);
GGML_ASSERT(ne[1] > 0);
result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
@@ -6720,20 +6726,35 @@ static void ggml_compute_backward(
GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
}
static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
// check if already visited
size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {
if (node->op != GGML_OP_NONE && compute) {
node->flags |= GGML_TENSOR_FLAG_COMPUTE;
}
const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
// This is the first time we see this node in the current graph.
cgraph->visited_hash_set.keys[node_hash_pos] = node;
ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
cgraph->use_counts[node_hash_pos] = 0;
} else {
if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
// already visited
if (compute) {
// update the compute flag regardless
for (int i = 0; i < GGML_MAX_SRC; ++i) {
struct ggml_tensor * src = node->src[i];
if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) {
ggml_visit_parents_graph(cgraph, src, true);
}
}
}
return node_hash_pos;
}
// This is the first time we see this node in the current graph.
cgraph->visited_hash_set.keys[node_hash_pos] = node;
ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
cgraph->use_counts[node_hash_pos] = 0;
for (int i = 0; i < GGML_MAX_SRC; ++i) {
const int k =
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
@@ -6742,7 +6763,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
struct ggml_tensor * src = node->src[k];
if (src) {
size_t src_hash_pos = ggml_visit_parents(cgraph, src);
const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);
// Update the use count for this operand.
cgraph->use_counts[src_hash_pos]++;
@@ -6773,17 +6794,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
return node_hash_pos;
}
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {
if (!expand) {
// TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
ggml_graph_clear(cgraph);
}
const int n0 = cgraph->n_nodes;
const int n_old = cgraph->n_nodes;
ggml_visit_parents(cgraph, tensor);
ggml_visit_parents_graph(cgraph, tensor, compute);
const int n_new = cgraph->n_nodes - n0;
const int n_new = cgraph->n_nodes - n_old;
GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
if (n_new > 0) {
@@ -6792,8 +6813,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
}
}
struct ggml_tensor * ggml_build_forward_select(
struct ggml_cgraph * cgraph,
struct ggml_tensor ** tensors,
int n_tensors,
int idx) {
GGML_ASSERT(idx >= 0 && idx < n_tensors);
for (int i = 0; i < n_tensors; i++) {
ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false);
}
return tensors[idx];
}
void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
ggml_build_forward_impl(cgraph, tensor, true);
ggml_build_forward_impl(cgraph, tensor, true, true);
}
void ggml_build_backward_expand(
@@ -7224,6 +7259,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
return false;
}
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
return false;
}
if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
continue;
}
@@ -7305,7 +7344,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node,
label);
}
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {
char color[16];
FILE * fp = ggml_fopen(filename, "w");
@@ -7326,7 +7365,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
snprintf(color, sizeof(color), "yellow");
} else if (grad) {
if (ggml_graph_find(gf, node)) {
if (ggml_graph_find(cgraph, node)) {
snprintf(color, sizeof(color), "green");
} else {
snprintf(color, sizeof(color), "lightblue");
-36
View File
@@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const {
return res;
}
bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
if (p0 < pos_chunk_start) {
return true;
}
} break;
case LLAMA_SWA_TYPE_SYMMETRIC:
{
const int32_t half_n_swa = (int32_t) n_swa / 2;
const int32_t pos_diff = p1 - p0;
// Mask if outside the symmetric window
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
return true;
}
} break;
}
return false;
}
bool llama_hparams::use_mrope() const {
return rope_sections[0] > 0 && rope_sections[1] > 0;
}
+38 -1
View File
@@ -3,6 +3,7 @@
#include "llama.h"
#include <array>
#include <cassert>
// bump if necessary
#define LLAMA_MAX_LAYERS 512
@@ -274,9 +275,45 @@ struct llama_hparams {
uint32_t n_layer_kv() const;
// note that this function uses different SWA parameters from those in the hparams
// note: inlined on purpose for performance reasons
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
{
} break;
case LLAMA_SWA_TYPE_STANDARD:
{
if (p1 - p0 >= (int32_t) n_swa) {
return true;
}
} break;
case LLAMA_SWA_TYPE_CHUNKED:
{
const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
if (p0 < pos_chunk_start) {
return true;
}
} break;
case LLAMA_SWA_TYPE_SYMMETRIC:
{
const int32_t half_n_swa = (int32_t) n_swa / 2;
const int32_t pos_diff = p1 - p0;
// Mask if outside the symmetric window
if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
return true;
}
} break;
}
return false;
}
bool use_mrope() const;
};
+212 -70
View File
@@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
const llama_seq_id seq_id_cell = cells.seq_get(idx);
// SWA mask
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
can_use = true;
}
}
@@ -1237,6 +1237,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
}
}
struct args_set_input_kq_mask {
const llama_hparams & hparams;
const llama_ubatch * ubatch;
const std::vector<llama_kv_cells> & v_cells;
const std::vector<uint32_t> & seq_to_stream;
uint32_t n_swa;
llama_swa_type swa_type;
int64_t n_kv;
int64_t n_stream;
int64_t n_tps;
};
template<bool causal, bool swa, bool is_2d, bool alibi>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
//const auto & hparams = args.hparams;
const auto & ubatch = args.ubatch;
const auto & v_cells = args.v_cells;
const auto & seq_to_stream = args.seq_to_stream;
const uint32_t n_swa = args.n_swa;
const llama_swa_type swa_type = args.swa_type;
const int64_t n_kv = args.n_kv;
const int64_t n_stream = args.n_stream;
const int64_t n_tps = args.n_tps;
// the min position in the batch for each sequence
llama_pos seq_pos_min[LLAMA_MAX_SEQ];
std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
const llama_seq_id seq_id = ubatch->seq_id[i][0];
seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
}
for (uint32_t s = 0; s < n_stream; ++s) {
// bookeeping of the KQ mask cells that could change for other tokens of the same sequence
std::unordered_map<llama_seq_id, uint32_t> seq_srct;
std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
for (uint32_t ii = 0; ii < n_tps; ++ii) {
const uint32_t i = s*n_tps + ii;
const llama_seq_id seq_id = ubatch->seq_id[i][0];
const auto & cells = v_cells.at(seq_to_stream[seq_id]);
llama_pos p0 = -1;
const llama_pos p1 = ubatch->pos[i];
// for M-RoPE
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
const uint64_t idst = n_kv*i;
// for tokens of the same sequence, the mask is mostly the same, so we can reuse it
// the only cells that could change are the ones that are with similar positions as the
// ones in the batch (i.e. due to causal masking, SWA, etc.)
// keep track of those cells and shortcut the loop to save time
// note: this optimization is not compatible with Alibi position encoding
// ref: https://github.com/ggml-org/llama.cpp/pull/18842
bool prev = false;
auto & idxs = seq_idxs[seq_id];
if (!alibi) {
if (seq_srct.find(seq_id) != seq_srct.end()) {
const uint32_t srct = seq_srct[seq_id];
const uint64_t idst_prev = n_kv*srct;
std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
prev = true;
} else {
idxs.clear();
idxs.reserve(ubatch->n_tokens + n_swa + 32);
seq_srct[seq_id] = i;
}
}
for (uint32_t jj = 0; jj < n_kv; ++jj) {
uint32_t j = jj;
// we have an exiting mask for this sequence -> update just seq_idxs
if (!alibi) {
if (prev) {
if (jj >= idxs.size()) {
break;
}
j = idxs[jj];
}
}
if (cells.is_empty(j)) {
goto skip;
}
// mask the token if not the same sequence
if (!cells.seq_has(j, seq_id)) {
goto skip;
}
p0 = cells.pos_get(j);
if (!alibi) {
if (!prev) {
// record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
idxs.push_back(j);
}
}
}
if (causal) {
// mask future tokens
if (p0 > p1) {
goto skip;
}
// M-RoPE causal mask
if (is_2d) {
if (p0 == p1) {
const auto & p0_ext = cells.ext_get(j);
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
goto skip;
}
}
}
}
// apply SWA if any
if (swa) {
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
goto skip;
}
}
if (alibi) {
data[idst + j] = -std::abs(p0 - p1);
} else {
data[idst + j] = 0.0f;
}
continue;
skip:
data[idst + j] = -INFINITY;
}
}
}
}
template<bool causal, bool swa, bool is_2d>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
const bool alibi = args.hparams.use_alibi;
if (alibi) {
set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data);
} else {
set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
}
}
template<bool causal, bool swa>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
const bool is_2d = args.ubatch->is_pos_2d();
if (is_2d) {
set_input_kq_mask_impl<causal, swa, true> (args, data);
} else {
set_input_kq_mask_impl<causal, swa, false>(args, data);
}
}
template<bool causal>
static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
if (swa) {
set_input_kq_mask_impl<causal, true> (args, data);
} else {
set_input_kq_mask_impl<causal, false>(args, data);
}
}
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const uint32_t n_tokens = ubatch->n_tokens;
@@ -1251,74 +1442,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
// n_tps == n_tokens_per_stream
const int64_t n_tps = n_tokens/n_stream;
std::fill(data, data + ggml_nelements(dst), -INFINITY);
//const int64_t t_start = ggml_time_us();
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
// Causal mask:
// xxx-------
// xxxx------
// xxxxx-----
// Non-causal mask:
// xxxxx-----
// xxxxx-----
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
// TODO: optimize this section
for (uint32_t h = 0; h < 1; ++h) {
for (uint32_t s = 0; s < n_stream; ++s) {
for (uint32_t ii = 0; ii < n_tps; ++ii) {
const uint32_t i = s*n_tps + ii;
const args_set_input_kq_mask args = {
/*.hparams =*/ hparams,
/*.ubatch =*/ ubatch,
/*.v_cells =*/ v_cells,
/*.seq_to_stream =*/ seq_to_stream,
/*.n_swa =*/ n_swa,
/*.swa_type =*/ swa_type,
/*.n_kv =*/ n_kv,
/*.n_stream =*/ n_stream,
/*.n_tps =*/ n_tps,
};
const llama_seq_id seq_id = ubatch->seq_id[i][0];
const auto & cells = v_cells[seq_to_stream[seq_id]];
const llama_pos p1 = ubatch->pos[i];
// for M-RoPE
const bool is_2d = ubatch->is_pos_2d();
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
for (uint32_t j = 0; j < n_kv; ++j) {
if (cells.is_empty(j)) {
continue;
}
// mask the token if not the same sequence
if (!cells.seq_has(j, seq_id)) {
continue;
}
const llama_pos p0 = cells.pos_get(j);
// mask future tokens
if (causal_attn && p0 > p1) {
continue;
}
// M-RoPE causal mask
if (causal_attn && is_2d && p0 == p1) {
const auto & p0_ext = cells.ext_get(j);
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
continue;
}
}
// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
}
data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
}
}
if (causal_attn) {
set_input_kq_mask_impl<true> (args, data);
} else {
set_input_kq_mask_impl<false>(args, data);
}
//const int64_t t_end = ggml_time_us();
//LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
}
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
@@ -1483,10 +1629,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
return gf;
}
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
-2
View File
@@ -257,8 +257,6 @@ private:
size_t size_k_bytes() const;
size_t size_v_bytes() const;
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
+5 -1
View File
@@ -265,7 +265,8 @@ struct llama_file::impl {
continue; // Interrupted by signal, retry
}
// Fallback to std::fread in case the DMA controller cannot access the buffer
if (errno == EFAULT) {
if (errno == EFAULT || errno == EINVAL) {
LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno));
auto curr_off = tell();
close(fd);
fd = -1;
@@ -384,6 +385,9 @@ int llama_file::file_id() const {
#ifdef _WIN32
return _fileno(pimpl->fp);
#else
if (pimpl->fd != -1) {
return pimpl->fd;
}
#if defined(fileno)
return fileno(pimpl->fp);
#else
+12 -6
View File
@@ -539,12 +539,18 @@ llama_model_loader::llama_model_loader(
files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io));
contexts.emplace_back(ctx);
use_direct_io = use_direct_io && files.back()->has_direct_io();
// Disable mmap in case Direct I/O is enabled and available
if (use_direct_io && use_mmap) {
use_mmap = false;
LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
if (use_mmap && use_direct_io) {
if (files.back()->has_direct_io()) {
// Disable mmap, as DirectIO is available
use_mmap = false;
LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
} else {
// Disable DirectIO and reopen file using std::fopen for mmap
use_direct_io = false;
files.pop_back();
files.emplace_back(new llama_file(fname.c_str(), "rb", false));
LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__);
}
}
// Save tensors data offset of the main file.
+1
View File
@@ -187,6 +187,7 @@ llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-jinja.cpp)
llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(
+45
View File
@@ -4679,6 +4679,37 @@ struct test_pool2d : public test_case {
}
};
// GGML_OP_POOL1D
struct test_pool1d : public test_case {
enum ggml_op_pool pool_type;
const ggml_type type_input;
const std::array<int64_t, 4> ne_input;
const int k0;
const int s0;
const int p0;
std::string vars() override {
return VARS_TO_STR6(pool_type, type_input, ne_input, k0, s0, p0);
}
test_pool1d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
ggml_type type_input = GGML_TYPE_F32,
std::array<int64_t,4> ne_input = {10, 1, 1, 1},
int k0 = 3, int s0 = 3, int p0 = 0)
: pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), s0(s0), p0(p0) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_set_param(input);
ggml_set_name(input, "input");
ggml_tensor * out = ggml_pool_1d(ctx, input, pool_type, k0, s0, p0);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_CONV_TRANSPOSE_1D
struct test_conv_transpose_1d : public test_case {
const std::array<int64_t, 4> ne_input;
@@ -7058,6 +7089,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) {
for (int s0 : {1, 2}) {
for (int p0 : {0, 1}) {
test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 10, 3, 2, 1 }, k0, s0, p0));
test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 11, 1, 3, 2 }, k0, s0, p0));
test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 128, 2, 1, 3 }, k0, s0, p0));
}
}
}
}
}
#if 0
// >4GB im2col destination. Too slow to run by default.
// Test cases taken from Wan2.1 T2V 1.3B.
+345 -9
View File
@@ -4,6 +4,7 @@
#include <cstdlib>
#include <nlohmann/json.hpp>
#include <sheredom/subprocess.h>
#include "jinja/runtime.h"
#include "jinja/parser.h"
@@ -31,12 +32,24 @@ static void test_array_methods(testing & t);
static void test_object_methods(testing & t);
static void test_fuzzing(testing & t);
static bool g_python_mode = false;
int main(int argc, char *argv[]) {
testing t(std::cout);
t.verbose = true;
if (argc >= 2) {
t.set_filter(argv[1]);
// usage: test-jinja [-py] [filter_regex]
// -py : enable python mode (use python jinja2 for rendering expected output)
// only use this for cross-checking, not for correctness
// note: the implementation of this flag is basic, only intented to be used by maintainers
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-py") {
g_python_mode = true;
} else {
t.set_filter(arg);
}
}
t.test("whitespace control", test_whitespace_control);
@@ -53,7 +66,9 @@ int main(int argc, char *argv[]) {
t.test("string methods", test_string_methods);
t.test("array methods", test_array_methods);
t.test("object methods", test_object_methods);
t.test("fuzzing", test_fuzzing);
if (!g_python_mode) {
t.test("fuzzing", test_fuzzing);
}
return t.summary();
}
@@ -176,6 +191,84 @@ static void test_conditionals(testing & t) {
json::object(),
"yes"
);
test_template(t, "is undefined falsy",
"{{ 'yes' if not y else 'no' }}",
json::object(),
"yes"
);
test_template(t, "is undefined attribute falsy",
"{{ 'yes' if not y.x else 'no' }}",
{{"y", true}},
"yes"
);
test_template(t, "is undefined key falsy",
"{{ 'yes' if not y['x'] else 'no' }}",
{{"y", {{}}}},
"yes"
);
test_template(t, "is empty array falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", json::array()}},
"yes"
);
test_template(t, "is empty object falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", json::object()}},
"yes"
);
test_template(t, "is empty string falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", ""}},
"yes"
);
test_template(t, "is 0 falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", 0}},
"yes"
);
test_template(t, "is 0.0 falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", 0.0}},
"yes"
);
test_template(t, "is non-empty array truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", json::array({""})}},
"yes"
);
test_template(t, "is non-empty object truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", {"x", false}}},
"yes"
);
test_template(t, "is non-empty string truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", "0"}},
"yes"
);
test_template(t, "is 1 truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", 1}},
"yes"
);
test_template(t, "is 1.0 truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", 1.0}},
"yes"
);
}
static void test_loops(testing & t) {
@@ -247,6 +340,12 @@ static void test_expressions(testing & t) {
"Bob"
);
test_template(t, "negative float (not dot notation)",
"{{ -1.0 }}",
json::object(),
"-1.0"
);
test_template(t, "bracket notation",
"{{ user['name'] }}",
{{"user", {{"name", "Bob"}}}},
@@ -383,6 +482,32 @@ static void test_filters(testing & t) {
"123"
);
test_template(t, "sort reverse",
"{% for i in items|sort(true) %}{{ i }}{% endfor %}",
{{"items", json::array({3, 1, 2})}},
"321"
);
test_template(t, "sort with attribute",
"{{ items|sort(attribute='name')|join(attribute='age') }}",
{{"items", json::array({
json({{"name", "c"}, {"age", 3}}),
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
})}},
"123"
);
test_template(t, "sort with numeric attribute",
"{{ items|sort(attribute=0)|join(attribute=1) }}",
{{"items", json::array({
json::array({3, "z"}),
json::array({1, "x"}),
json::array({2, "y"}),
})}},
"xyz"
);
test_template(t, "join",
"{{ items|join(', ') }}",
{{"items", json::array({"a", "b", "c"})}},
@@ -534,6 +659,66 @@ static void test_literals(testing & t) {
json::object(),
"1"
);
test_template(t, "integer|abs",
"{{ -42 | abs }}",
json::object(),
"42"
);
test_template(t, "integer|float",
"{{ 42 | float }}",
json::object(),
"42.0"
);
test_template(t, "integer|tojson",
"{{ 42 | tojson }}",
json::object(),
"42"
);
test_template(t, "float|abs",
"{{ -3.14 | abs }}",
json::object(),
"3.14"
);
test_template(t, "float|int",
"{{ 3.14 | int }}",
json::object(),
"3"
);
test_template(t, "float|tojson",
"{{ 3.14 | tojson }}",
json::object(),
"3.14"
);
test_template(t, "string|tojson",
"{{ 'hello' | tojson }}",
json::object(),
"\"hello\""
);
test_template(t, "boolean|int",
"{{ true | int }}",
json::object(),
"1"
);
test_template(t, "boolean|float",
"{{ true | float }}",
json::object(),
"1.0"
);
test_template(t, "boolean|tojson",
"{{ true | tojson }}",
json::object(),
"true"
);
}
static void test_comments(testing & t) {
@@ -934,7 +1119,17 @@ static void test_array_methods(testing & t) {
);
test_template(t, "array|join attribute",
"{{ arr|join(attribute=0) }}",
"{{ arr|join(attribute='age') }}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
json({{"name", "c"}, {"age", 3}}),
})}},
"123"
);
test_template(t, "array|join numeric attribute",
"{{ arr|join(attribute=-1) }}",
{{"arr", json::array({json::array({1}), json::array({2}), json::array({3})})}},
"123"
);
@@ -957,8 +1152,8 @@ static void test_array_methods(testing & t) {
"a,b,c,d"
);
test_template(t, "array.map() with attribute",
"{% for v in arr.map('age') %}{{ v }} {% endfor %}",
test_template(t, "array|map with attribute",
"{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
@@ -967,8 +1162,28 @@ static void test_array_methods(testing & t) {
"1 2 3 "
);
test_template(t, "array.map() with numeric attribute",
"{% for v in arr.map(0) %}{{ v }} {% endfor %}",
test_template(t, "array|map with attribute default",
"{% for v in arr|map(attribute='age', default=3) %}{{ v }} {% endfor %}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
json({{"name", "c"}}),
})}},
"1 2 3 "
);
test_template(t, "array|map without attribute default",
"{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
json({{"name", "c"}}),
})}},
"1 2 "
);
test_template(t, "array|map with numeric attribute",
"{% for v in arr|map(attribute=0) %}{{ v }} {% endfor %}",
{{"arr", json::array({
json::array({10, "x"}),
json::array({20, "y"}),
@@ -977,6 +1192,22 @@ static void test_array_methods(testing & t) {
"10 20 30 "
);
test_template(t, "array|map with negative attribute",
"{% for v in arr|map(attribute=-1) %}{{ v }} {% endfor %}",
{{"arr", json::array({
json::array({10, "x"}),
json::array({20, "y"}),
json::array({30, "z"}),
})}},
"x y z "
);
test_template(t, "array|map with filter",
"{{ arr|map('int')|sum }}",
{{"arr", json::array({"1", "2", "3"})}},
"6"
);
// not used by any chat templates
// test_template(t, "array.insert()",
// "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",
@@ -1063,9 +1294,21 @@ static void test_object_methods(testing & t) {
{{"obj", {{"items", json::array({1, 2, 3})}}}},
"{\"items\": [1, 2, 3]}"
);
test_template(t, "object attribute and key access",
"{{ obj.keys()|join(',') }} vs {{ obj['keys'] }} vs {{ obj.test }}",
{{"obj", {{"keys", "value"}, {"test", "attr_value"}}}},
"keys,test vs value vs attr_value"
);
test_template(t, "env should not have object methods",
"{{ keys is undefined }} {{ obj.keys is defined }}",
{{"obj", {{"a", "b"}}}},
"True True"
);
}
static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
t.test(name, [&tmpl, &vars, &expect](testing & t) {
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(tmpl);
@@ -1098,6 +1341,99 @@ static void test_template(testing & t, const std::string & name, const std::stri
});
}
// keep this in-sync with https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
// note: we use SandboxedEnvironment instead of ImmutableSandboxedEnvironment to allow usage of in-place array methods like append() and pop()
static std::string py_script = R"(
import jinja2
import jinja2.ext as jinja2_ext
import json
import sys
from datetime import datetime
from jinja2.sandbox import SandboxedEnvironment
tmpl = json.loads(sys.argv[1])
vars_json = json.loads(sys.argv[2])
env = SandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[jinja2_ext.loopcontrols],
)
def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)
env.filters["tojson"] = lambda x, ensure_ascii=False, indent=None, separators=None, sort_keys=False: json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
env.globals["raise_exception"] = raise_exception
template = env.from_string(tmpl)
result = template.render(**vars_json)
print(result, end='')
)";
static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
t.test(name, [&tmpl, &vars, &expect](testing & t) {
// Prepare arguments
std::string tmpl_json = json(tmpl).dump();
std::string vars_json = vars.dump();
#ifdef _WIN32
const char * python_executable = "python.exe";
#else
const char * python_executable = "python3";
#endif
const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), NULL};
struct subprocess_s subprocess;
int options = subprocess_option_combined_stdout_stderr
| subprocess_option_no_window
| subprocess_option_inherit_environment
| subprocess_option_search_user_path;
int result = subprocess_create(command_line, options, &subprocess);
if (result != 0) {
t.log("Failed to create subprocess, error code: " + std::to_string(result));
t.assert_true("subprocess creation", false);
return;
}
// Read output
std::string output;
char buffer[1024];
FILE * p_stdout = subprocess_stdout(&subprocess);
while (fgets(buffer, sizeof(buffer), p_stdout)) {
output += buffer;
}
int process_return;
subprocess_join(&subprocess, &process_return);
subprocess_destroy(&subprocess);
if (process_return != 0) {
t.log("Python script failed with exit code: " + std::to_string(process_return));
t.log("Output: " + output);
t.assert_true("python execution", false);
return;
}
if (!t.assert_true("Template render mismatch", expect == output)) {
t.log("Template: " + json(tmpl).dump());
t.log("Expected: " + json(expect).dump());
t.log("Python : " + json(output).dump());
}
});
}
static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
if (g_python_mode) {
test_template_py(t, name, tmpl, vars, expect);
} else {
test_template_cpp(t, name, tmpl, vars, expect);
}
}
//
// fuzz tests to ensure no crashes occur on malformed inputs
//
+27 -5
View File
@@ -71,14 +71,16 @@ struct cli_context {
std::string generate_completion(result_timings & out_timings) {
server_response_reader rd = ctx_server.get_response_reader();
auto formatted = format_chat();
{
// TODO: reduce some copies here in the future
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_input = messages; // copy
task.cli_files = input_files; // copy
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_prompt = formatted.prompt; // copy
task.cli_files = input_files; // copy
task.cli = true;
rd.post_task({std::move(task)});
}
@@ -156,6 +158,26 @@ struct cli_context {
return content;
}
}
common_chat_params format_chat() {
auto meta = ctx_server.get_meta();
auto & chat_params = meta.chat_params;
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = {}; // TODO
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = chat_params.use_jinja;
inputs.parallel_tool_calls = false;
inputs.add_generation_prompt = true;
inputs.reasoning_format = chat_params.reasoning_format;
inputs.enable_thinking = chat_params.enable_thinking;
// Apply chat template to the list of messages
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
}
};
int main(int argc, char ** argv) {
-1
View File
@@ -12,7 +12,6 @@ ggml_cgraph * clip_graph_conformer::build() {
ggml_build_forward_expand(gf, pos_emb);
ggml_tensor * inp = build_inp_raw(1);
cb(inp, "input", -1);
auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+2 -2
View File
@@ -831,7 +831,7 @@ static void handle_media(
// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
const oaicompat_parser_options & opt,
const server_chat_params & opt,
std::vector<raw_buffer> & out_files)
{
json llama_params;
@@ -1012,7 +1012,7 @@ json oaicompat_chat_params_parse(
}
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs);
/* Append assistant prefilled message */
if (prefill_assistant_message) {
+7 -7
View File
@@ -274,25 +274,25 @@ std::vector<server_tokens> tokenize_input_prompts(
// OAI utils
//
// used by /completions endpoint
json oaicompat_completion_params_parse(const json & body);
struct oaicompat_parser_options {
struct server_chat_params {
bool use_jinja;
bool prefill_assistant;
common_reasoning_format reasoning_format;
std::map<std::string,std::string> chat_template_kwargs;
common_chat_templates * tmpls;
std::map<std::string, std::string> chat_template_kwargs; // mapping key --> json value
common_chat_templates_ptr tmpls;
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
std::string media_path;
};
// used by /completions endpoint
json oaicompat_completion_params_parse(const json & body);
// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
const oaicompat_parser_options & opt,
const server_chat_params & opt,
std::vector<raw_buffer> & out_files);
// convert Anthropic Messages API format to OpenAI Chat Completions API format
+70 -81
View File
@@ -534,8 +534,8 @@ public:
server_queue queue_tasks;
server_response queue_results;
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
// note: chat_params must not be refreshed upon existing sleeping state
server_chat_params chat_params;
~server_context_impl() {
if (!sleeping) {
@@ -688,15 +688,6 @@ private:
llama_init_dft->free_context();
}
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs);
} catch (const std::exception & e) {
SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml");
}
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
if (!is_resume) {
@@ -845,30 +836,6 @@ private:
model_name = model_path.filename().string();
}
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
SRV_INF("thinking = %d\n", enable_thinking);
oai_parser_opt = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* media_path */ params_base.media_path,
};
// print sample chat example to make it clear which template is used
// @ngxson modern templates are too long, spam the logs; printing the example is enough
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
// common_chat_templates_source(chat_templates.get()),
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
if (!is_resume) {
return init();
}
@@ -907,6 +874,42 @@ private:
}
}
// populate chat template params
{
common_chat_templates_ptr chat_templates;
try {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
} catch (const std::exception & e) {
SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
SRV_ERR("%s: please consider disabling jinja via --no-jinja, or use a custom chat template via --chat-template\n", __func__);
SRV_ERR("%s: for example: --no-jinja --chat-template chatml\n", __func__);
return false;
}
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
chat_params = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* tmpls */ std::move(chat_templates),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* media_path */ params_base.media_path,
};
}
return true;
}
@@ -1326,11 +1329,12 @@ private:
}
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
const size_t n_probs = slot.task->params.sampling.n_probs;
const size_t n_probs_request = slot.task->params.sampling.n_probs;
if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
const size_t max_probs = cur_p->size;
const size_t n_probs = std::min(max_probs, n_probs_request);
// set probability for sampled token
for (size_t i = 0; i < max_probs; i++) {
@@ -1341,8 +1345,8 @@ private:
}
// set probability for top n_probs tokens
result.probs.reserve(max_probs);
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
result.probs.reserve(n_probs);
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur_p->data[i].id,
common_token_to_piece(ctx, cur_p->data[i].id, special),
@@ -1352,9 +1356,11 @@ private:
} else {
// TODO: optimize this with min-p optimization
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
const size_t max_probs = cur.size();
const size_t n_probs = std::min(max_probs, n_probs_request);
// set probability for sampled token
for (size_t i = 0; i < cur.size(); i++) {
for (size_t i = 0; i < max_probs; i++) {
// set probability for sampled token
if (cur[i].id == result.tok) {
result.prob = cur[i].p;
@@ -1364,7 +1370,7 @@ private:
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur[i].id,
common_token_to_piece(ctx, cur[i].id, special),
@@ -1585,32 +1591,14 @@ private:
// tokenize the input if it's set by CLI, return false on error
bool tokenize_cli_input(server_task & task) {
GGML_ASSERT(task.cli_input != nullptr);
try {
auto & opt = oai_parser_opt;
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(task.cli_input);
inputs.tools = {}; // TODO
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = false;
inputs.add_generation_prompt = true;
inputs.reasoning_format = opt.reasoning_format;
inputs.enable_thinking = opt.enable_thinking;
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
// tokenize the resulting prompt
auto & prompt = chat_params.prompt;
auto & prompt = task.cli_prompt;
if (mctx != nullptr) {
task.tokens = process_mtmd_prompt(mctx, prompt, task.cli_files);
} else {
task.tokens = std::move(tokenize_input_prompts(vocab, mctx, prompt, true, true)[0]);
}
task.cli_input.clear();
task.cli_prompt.clear();
task.cli_files.clear();
} catch (const std::exception & e) {
send_error(task, std::string("Failed to format input: ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
@@ -1686,7 +1674,7 @@ private:
{
// special case: if input is provided via CLI, tokenize it first
// otherwise, no need to tokenize as it's already done inside the HTTP thread
if (task.cli_input != nullptr) {
if (task.cli) {
if (!tokenize_cli_input(task)) {
break;
}
@@ -2898,8 +2886,6 @@ server_response_reader server_context::get_response_reader() {
}
server_context_meta server_context::get_meta() const {
auto tool_use_src = common_chat_templates_source(impl->chat_templates.get(), "tool_use");
auto bos_id = llama_vocab_bos(impl->vocab);
auto eos_id = llama_vocab_eos(impl->vocab);
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : "";
@@ -2910,14 +2896,13 @@ server_context_meta server_context::get_meta() const {
/* model_name */ impl->model_name,
/* model_path */ impl->params_base.model.path,
/* has_mtmd */ impl->mctx != nullptr,
/* has_inp_image */ impl->oai_parser_opt.allow_image,
/* has_inp_audio */ impl->oai_parser_opt.allow_audio,
/* has_inp_image */ impl->chat_params.allow_image,
/* has_inp_audio */ impl->chat_params.allow_audio,
/* json_webui_settings */ impl->json_webui_settings,
/* slot_n_ctx */ impl->get_slot_n_ctx(),
/* pooling_type */ llama_pooling_type(impl->ctx),
/* chat_template */ common_chat_templates_source(impl->chat_templates.get()),
/* chat_template_tool_use */ tool_use_src ? tool_use_src : "",
/* chat_params */ impl->chat_params,
/* bos_token_str */ bos_token_str,
/* eos_token_str */ eos_token_str,
@@ -3199,8 +3184,8 @@ void server_routes::init_routes() {
// this endpoint can be accessed during sleeping
// the next LOC is to avoid someone accidentally use ctx_server
bool server_ctx; // do NOT delete this line
GGML_UNUSED(server_ctx);
bool ctx_server; // do NOT delete this line
GGML_UNUSED(ctx_server);
res->ok({{"status", "ok"}});
return res;
@@ -3390,8 +3375,8 @@ void server_routes::init_routes() {
// this endpoint can be accessed during sleeping
// the next LOC is to avoid someone accidentally use ctx_server
bool server_ctx; // do NOT delete this line
GGML_UNUSED(server_ctx);
bool ctx_server; // do NOT delete this line
GGML_UNUSED(ctx_server);
task_params tparams;
tparams.sampling = params.sampling;
@@ -3400,6 +3385,9 @@ void server_routes::init_routes() {
{ "n_ctx", meta->slot_n_ctx },
};
std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), "");
std::string tmpl_tools = common_chat_templates_source(meta->chat_params.tmpls.get(), "tool_use");
json props = {
{ "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", params.n_parallel },
@@ -3414,15 +3402,15 @@ void server_routes::init_routes() {
{ "endpoint_metrics", params.endpoint_metrics },
{ "webui", params.webui },
{ "webui_settings", meta->json_webui_settings },
{ "chat_template", meta->chat_template },
{ "chat_template", tmpl_default },
{ "bos_token", meta->bos_token_str },
{ "eos_token", meta->eos_token_str },
{ "build_info", meta->build_info },
{ "is_sleeping", queue_tasks.is_sleeping() },
};
if (params.use_jinja) {
if (!meta->chat_template_tool_use.empty()) {
props["chat_template_tool_use"] = meta->chat_template_tool_use;
if (!tmpl_tools.empty()) {
props["chat_template_tool_use"] = tmpl_tools;
}
}
res->ok(props);
@@ -3443,6 +3431,7 @@ void server_routes::init_routes() {
this->get_api_show = [this](const server_http_req &) {
auto res = create_response();
std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), "");
json data = {
{
"model_info", {
@@ -3451,7 +3440,7 @@ void server_routes::init_routes() {
},
{"modelfile", ""},
{"parameters", ""},
{"template", meta->chat_template},
{"template", tmpl_default},
{"details", {
{"parent_model", ""},
{"format", "gguf"},
@@ -3576,7 +3565,7 @@ void server_routes::init_routes() {
json body = json::parse(req.body);
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
return handle_completions_impl(
req,
@@ -3592,7 +3581,7 @@ void server_routes::init_routes() {
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
return handle_completions_impl(
req,
@@ -3608,7 +3597,7 @@ void server_routes::init_routes() {
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
json prompt = body_parsed.at("prompt");
@@ -3624,7 +3613,7 @@ void server_routes::init_routes() {
json body = json::parse(req.body);
json data = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
res->ok({{ "prompt", std::move(data.at("prompt")) }});
return res;
@@ -3635,8 +3624,8 @@ void server_routes::init_routes() {
// this endpoint can be accessed during sleeping
// the next LOC is to avoid someone accidentally use ctx_server
bool server_ctx; // do NOT delete this line
GGML_UNUSED(server_ctx);
bool ctx_server; // do NOT delete this line
GGML_UNUSED(ctx_server);
json models = {
{"models", {
+2 -3
View File
@@ -20,9 +20,8 @@ struct server_context_meta {
int slot_n_ctx;
enum llama_pooling_type pooling_type;
// chat template
std::string chat_template;
std::string chat_template_tool_use;
// chat params
server_chat_params & chat_params;
// tokens
std::string bos_token_str;
+4 -2
View File
@@ -130,8 +130,10 @@ struct server_task {
task_params params;
server_tokens tokens;
// only used by CLI, this delegates the tokenization to the server
json cli_input = nullptr;
// only used by CLI, this allow tokenizing CLI inputs on server side
// we need this because mtmd_context and vocab are not accessible outside of server_context
bool cli = false;
std::string cli_prompt;
std::vector<raw_buffer> cli_files;
server_task_type type;