Compare commits

...

20 Commits

Author SHA1 Message Date
Piotr Wilkin (ilintar) 0ec191e1d7 vocab: add gemma4 tokenizer tests, fix edge case (#21534)
* YATF (Yet Another Tokenizer Fix) for Gemma 4. With tests!
* Remove unnecessary hash  from update script.
* minor: move constant
2026-04-09 11:41:14 +02:00
Kwa Jie Hao 243532e556 jinja : support ensure_ascii=true, string repetition and int/float self-filtering (#21623)
* feat: jinja engine improvements for reka-edge

Port three Jinja engine improvements needed for the reka-edge model:
1. Python-style string repetition ("ab" * 3 → "ababab")
2. ensure_ascii=true support for tojson filter (escapes non-ASCII to \uXXXX)
3. int() builtin on value_int_t (identity, needed for Reka Edge template)

* fix: escape invalid utf8 bytes when ensure_ascii=true

The json_ensure_ascii_preserving_format function does not correctly
handle an edge case where if UTF-8 parsing fails, it adds the non-ascii
character back to the output as a raw byte.

This commit fixes that by adding the unicode standard replacement
character \\ufffd to the output instead. This is the standard behavior
for various programming languages like Python, Rust, Go, etc.

* chore: address PR comments

1. Add todo comment for supporting string repetition for array/tuples
2. Add support for float identity operation
3. Move invalid ascii test case to test_fuzzing

* chore: accept suggestion for common/jinja/value.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-04-09 11:28:33 +02:00
Georgi Gerganov 5e9c635463 metal : add missing mm-id specializations for q1_0 (#21662) 2026-04-09 10:54:00 +03:00
Aleksander Grygier 9949ad08f6 fix: Model Selector choice sync (#21628) 2026-04-09 09:46:27 +02:00
AUTOMATIC1111 3ee9da0e4f server : fix grammar commandline args (#21543)
Co-authored-by: AUTOMATIC <->
2026-04-09 10:16:54 +03:00
Aleksander Grygier 75511a8d7e webui: Add option to pre-encode conversation for faster next turns (#21034) 2026-04-09 09:10:18 +02:00
Akarshan Biswas b54cb2e3d0 sycl : add flash-attn support for head size 512 (#21654)
* sycl : add flash-attn support for head size 512

This patch extends the SYCL Flash Attention implementation to support head sizes (DKQ/DV) of 512.

Changes:
- Added DKQ/DV 512 cases to both tile and vector Flash Attention kernels.
- Updated kernel selection logic to allow vector kernels for head sizes up to 512 (previously 256).
- Removed unused/redundant AMD and RDNA-specific configuration functions in `fattn-tile.hpp`.
- Refactored `ggml_backend_sycl_buffer_init_tensor` to use a switch statement for clearer tensor extra buffer initialization.
- Added necessary template instances for the new 512 head size across various quantization types.

* remove defunct mxfp4 reorder from setting buffer type
2026-04-09 09:36:48 +03:00
Marxist-Leninist 8a65a7a8ee ci: drop v5 all: composition from labeler.yml (#21627)
actions/labeler@v6 removed the `all:` / `any:` composition keys.
The `server/webui` and `server` entries used `all:` to combine
`any-glob-to-any-file` with negated `all-globs-to-all-files`,
which now errors on every PR with:

    Unknown config options were under "changed-files": all

Flatten both entries to a single `any-glob-to-any-file`. PRs
touching both webui and other server files will now receive both
labels instead of only `server/webui`.

Co-authored-by: Marxist-Leninist <noreply@users.noreply.github.com>
2026-04-09 08:20:19 +02:00
Ruben Ortlam 8a132faaa0 vulkan: unify type macros to use Vx instead of _VECx (#21605) 2026-04-09 07:31:51 +02:00
Adrien Gallouët 4293919068 common : skip non-primary GGUF split files when selecting model (#21633)
We should not assume files are listed in order.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-04-09 07:28:06 +02:00
Aman Gupta d12cc3d1ca CUDA: also store node->src->data ptrs for equality check (#21635)
* CUDA: also store node->src->data ptrs for equality check

* address review comments
2026-04-09 01:01:56 +08:00
RealOrko 2dcb7f74ed fix: free ctx_copy in ggml_opt_free to plug per-training-session leak (#21592)
* fix: free ctx_copy in ggml_opt_free to plug per-training-session leak

ggml_opt_alloc populates opt_ctx->ctx_copy via a free+init pair every
time the allocated graph shape changes. The last ctx_copy from the
final ggml_opt_alloc call survives until ggml_opt_free is invoked,
but ggml_opt_free was only freeing ctx_static and ctx_cpu, never
ctx_copy. Each opt_ctx lifetime therefore leaks the final per-batch
context — ~900 KB for a typical GNN training session in
sindarin-pkg-tensor, surfaced via AddressSanitizer.

ctx_copy is nullptr-initialized and ggml_free() handles NULL safely,
so the new release is guard-free.

* Update ggml/src/ggml-opt.cpp

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: realorko <realorko@nowhere.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-08 17:40:15 +02:00
Yuri Khrustalev 660600081f server: respect the ignore eos flag (#21203) 2026-04-08 17:12:15 +02:00
Aldehir Rojas d9a12c82f0 vocab : remove </s> eog token if gemma4 (#21492) 2026-04-08 09:53:06 -05:00
Georgi Gerganov 4a05e0c566 webui : send both backend_sampling == false/true (#18781)
* webui : send both backend_sampling == false/true

* feat: Parameter sync

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2026-04-08 16:35:52 +02:00
John Eismeier e9fd96283d Propose fix a couple of typos (#21581)
Signed-off-by: John E <jeis4wpi@outlook.com>
2026-04-08 16:29:03 +02:00
Erik Scholz 3ba12fed0a kv-cache : extend cache quantization checks (#21586)
to also check for enabled flash attention, instead of just auto.
2026-04-08 16:08:57 +03:00
Reese Levine 5473949070 webgpu : Query for adapter support when registering WebGPU backend (#21579) 2026-04-08 16:08:29 +03:00
Pasha Khosravi dcdcbad42a metal: Q1_0 backend (#21528)
* initial Q1_0 Metal backend

* tuning q1_0 metal kernels

* add Q1_0 to test-backend-ops

* add Q1_0<->F32 copy test

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-04-08 16:07:47 +03:00
Georgi Gerganov 5764d7c6a6 gemma : perform per-layer projections in the first layer (#21612)
* gemma : reduce graph splits by keeping per-layer ops in the input layer

* gemma : put the per-layer proj in the first layer

* cont : move the projection before the layer loop
2026-04-08 16:06:30 +03:00
96 changed files with 1345 additions and 526 deletions
+5 -13
View File
@@ -75,21 +75,13 @@ android:
- examples/llama.android/**
server/webui:
- changed-files:
- all:
- any-glob-to-any-file:
- tools/server/webui/**
- tools/server/public/**
- all-globs-to-all-files:
- '!tools/server/webui/**'
- '!tools/server/public/**'
- any-glob-to-any-file:
- tools/server/webui/**
- tools/server/public/**
server:
- changed-files:
- all:
- any-glob-to-any-file:
- tools/server/**
- all-globs-to-all-files:
- '!tools/server/webui/**'
- '!tools/server/public/**'
- any-glob-to-any-file:
- tools/server/**
+11 -1
View File
@@ -591,6 +591,10 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
for (const auto & f : files) {
if (gguf_filename_is_model(f.path) &&
std::regex_search(f.path, pattern)) {
auto split = get_gguf_split_info(f.path);
if (split.count > 1 && split.index != 1) {
continue;
}
return f;
}
}
@@ -600,6 +604,10 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
if (tag.empty()) {
for (const auto & f : files) {
if (gguf_filename_is_model(f.path)) {
auto split = get_gguf_split_info(f.path);
if (split.count > 1 && split.index != 1) {
continue;
}
return f;
}
}
@@ -618,6 +626,7 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) {
}
struct hf_plan {
hf_cache::hf_file primary;
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
};
@@ -663,6 +672,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
}
}
plan.primary = primary;
plan.model_files = get_split_files(all, primary);
if (opts.download_mmproj) {
@@ -749,7 +759,7 @@ common_download_model_result common_download_model(const common_params_model
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.model_files[0].final_path;
result.model_path = hf.primary.final_path;
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
+17
View File
@@ -251,6 +251,23 @@ value binary_expression::execute_impl(context & ctx) {
return res;
}
// Python-style string repetition
// TODO: support array/tuple repetition (e.g., [1, 2] * 3 → [1, 2, 1, 2, 1, 2])
if (op.value == "*" &&
((is_val<value_string>(left_val) && is_val<value_int>(right_val)) ||
(is_val<value_int>(left_val) && is_val<value_string>(right_val)))) {
const auto & str = is_val<value_string>(left_val) ? left_val->as_string() : right_val->as_string();
const int64_t repeat = is_val<value_int>(right_val) ? right_val->as_int() : left_val->as_int();
auto res = mk_val<value_string>();
if (repeat <= 0) {
return res;
}
for (int64_t i = 0; i < repeat; ++i) {
res->val_str = res->val_str.append(str);
}
return res;
}
// String membership
if (is_val<value_string>(left_val) && is_val<value_string>(right_val)) {
// case: "a" in "abc"
+90 -3
View File
@@ -1,4 +1,5 @@
#include "runtime.h"
#include "unicode.h"
#include "value.h"
// for converting from JSON to jinja values
@@ -154,6 +155,83 @@ static value test_compare_fn(const func_args & args) {
return mk_val<value_bool>(value_compare(args.get_pos(0), args.get_pos(1), op));
}
static void append_codepoint_as_ascii_json_escape(std::string & out, uint32_t codepoint) {
auto append_u16 = [&out](uint32_t value) {
char buf[8];
snprintf(buf, sizeof(buf), "\\u%04x", static_cast<unsigned int>(value));
out += buf;
};
if (codepoint <= 0xFFFF) {
append_u16(codepoint);
return;
}
codepoint -= 0x10000;
append_u16(0xD800 + ((codepoint >> 10) & 0x3FF));
append_u16(0xDC00 + (codepoint & 0x3FF));
}
static std::string json_ensure_ascii_preserving_format(const std::string & json_str) {
std::string output;
output.reserve(json_str.size());
bool in_string = false;
bool escaped = false;
for (size_t pos = 0; pos < json_str.size();) {
const char ch = json_str[pos];
if (!in_string) {
output.push_back(ch);
if (ch == '"') {
in_string = true;
}
++pos;
continue;
}
if (escaped) {
output.push_back(ch);
escaped = false;
++pos;
continue;
}
if (ch == '\\') {
output.push_back(ch);
escaped = true;
++pos;
continue;
}
if (ch == '"') {
output.push_back(ch);
in_string = false;
++pos;
continue;
}
const unsigned char uch = static_cast<unsigned char>(ch);
if (uch < 0x80) {
output.push_back(ch);
++pos;
continue;
}
auto parsed = common_parse_utf8_codepoint(json_str, pos);
if (parsed.status != utf8_parse_result::SUCCESS) {
output += "\\ufffd";
++pos;
continue;
}
append_codepoint_as_ascii_json_escape(output, parsed.codepoint);
pos += parsed.bytes_consumed;
}
return output;
}
static value tojson(const func_args & args) {
args.ensure_count(1, 5);
value val_ascii = args.get_kwarg_or_pos("ensure_ascii", 1);
@@ -169,16 +247,17 @@ static value tojson(const func_args & args) {
if (is_val<value_int>(val_indent)) {
indent = static_cast<int>(val_indent->as_int());
}
if (val_ascii->as_bool()) { // undefined == false
throw not_implemented_exception("tojson ensure_ascii=true not implemented");
}
if (val_sort->as_bool()) { // undefined == false
throw not_implemented_exception("tojson sort_keys=true not implemented");
}
const bool ensure_ascii = val_ascii->as_bool(); // undefined == false
auto separators = (is_val<value_array>(val_separators) ? val_separators : mk_val<value_array>())->as_array();
std::string item_sep = separators.size() > 0 ? separators[0]->as_string().str() : (indent < 0 ? ", " : ",");
std::string key_sep = separators.size() > 1 ? separators[1]->as_string().str() : ": ";
std::string json_str = value_to_json(args.get_pos(0), indent, item_sep, key_sep);
if (ensure_ascii) {
json_str = json_ensure_ascii_preserving_format(json_str);
}
return mk_val<value_string>(json_str);
}
@@ -460,6 +539,10 @@ const func_builtins & value_int_t::get_builtins() const {
int64_t val = args.get_pos(0)->as_int();
return mk_val<value_int>(val < 0 ? -val : val);
}},
{"int", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
return mk_val<value_int>(args.get_pos(0)->as_int());
}},
{"float", [](const func_args & args) -> value {
args.ensure_vals<value_int>();
double val = static_cast<double>(args.get_pos(0)->as_int());
@@ -486,6 +569,10 @@ const func_builtins & value_float_t::get_builtins() const {
int64_t val = static_cast<int64_t>(args.get_pos(0)->as_float());
return mk_val<value_int>(val);
}},
{"float", [](const func_args & args) -> value {
args.ensure_vals<value_float>();
return mk_val<value_float>(args.get_pos(0)->as_float());
}},
{"safe", tojson},
{"string", tojson},
{"tojson", tojson},
+5 -1
View File
@@ -1173,7 +1173,11 @@ struct ggml_cuda_graph {
std::vector<cudaGraphNode_t> nodes;
bool disable_due_to_gpu_arch = false;
bool warmup_complete = false;
std::vector<ggml_tensor> nodes_copy;
struct node_properties {
ggml_tensor node;
void * node_src_data_ptrs[GGML_MAX_SRC];
};
std::vector<node_properties> node_props;
bool is_enabled() const {
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+14 -7
View File
@@ -2979,18 +2979,25 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
// Check if the graph size has changed
if ((int)graph->nodes_copy.size() != cgraph->n_nodes) {
if ((int)graph->node_props.size() != cgraph->n_nodes) {
res = true;
graph->nodes_copy.resize(cgraph->n_nodes);
graph->node_props.resize(cgraph->n_nodes);
}
for (int i = 0; i < cgraph->n_nodes; i++) {
if (!res) {
if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
res = true;
}
ggml_cuda_graph::node_properties prop = {};
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
for (int j = 0; j < GGML_MAX_SRC; ++j) {
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
}
memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor));
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
res = true;
}
graph->node_props[i] = prop;
}
return res;
+10
View File
@@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
suffix = ne00 % 4 == 0 ? "_4" : "";
}
} break;
case GGML_TYPE_Q1_0:
{
nsg = N_SG_Q1_0;
nr0 = N_R0_Q1_0;
} break;
case GGML_TYPE_Q4_0:
{
nsg = N_SG_Q4_0;
@@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
smem = 32*sizeof(float)*nr0;
suffix = ne00 % 4 == 0 ? "_4" : "";
} break;
case GGML_TYPE_Q1_0:
{
nsg = N_SG_Q1_0;
nr0 = N_R0_Q1_0;
} break;
case GGML_TYPE_Q4_0:
{
nsg = N_SG_Q4_0;
+2
View File
@@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
default:
return false;
}
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
+3
View File
@@ -8,6 +8,9 @@
//
// TODO: for optimal performance, become function of the device and work size
#define N_R0_Q1_0 8
#define N_SG_Q1_0 2
#define N_R0_Q4_0 4
#define N_SG_Q4_0 2
+1
View File
@@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
op->src[0]->type == GGML_TYPE_F16 ||
op->src[0]->type == GGML_TYPE_BF16 ||
op->src[0]->type == GGML_TYPE_Q1_0 ||
op->src[0]->type == GGML_TYPE_Q4_0 ||
op->src[0]->type == GGML_TYPE_Q4_1 ||
op->src[0]->type == GGML_TYPE_Q5_0 ||
+189
View File
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
}
#endif
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
device const uint8_t * qs = xb->qs;
const float d = xb->d;
const float neg_d = -d;
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
const uint8_t b0 = qs[byte_offset];
const uint8_t b1 = qs[byte_offset + 1];
float4x4 reg_f;
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
const float d = xb->d;
const float neg_d = -d;
const int base = il * 4;
const uint8_t byte = xb->qs[base / 8];
const int s = base % 8;
float4 reg_f;
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
reg = (type4) reg_f;
}
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
}
}
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
float sum_abs = 0.0f;
for (int j = 0; j < QK1_0; j++) {
sum_abs += fabs(src[j]);
}
dst.d = sum_abs / QK1_0;
for (int j = 0; j < QK1_0 / 8; j++) {
dst.qs[j] = 0;
}
for (int j = 0; j < QK1_0; j++) {
if (src[j] >= 0.0f) {
dst.qs[j / 8] |= (1 << (j % 8));
}
}
}
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
@@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
}
}
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
device const uint8_t * qs = qb_curr->qs + il / 8;
const uint8_t b0 = qs[0];
const uint8_t b1 = qs[1];
float acc = 0.0f;
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
acc += select(0.0f, yl[10], bool(b1 & 0x04));
acc += select(0.0f, yl[11], bool(b1 & 0x08));
acc += select(0.0f, yl[12], bool(b1 & 0x10));
acc += select(0.0f, yl[13], bool(b1 & 0x20));
acc += select(0.0f, yl[14], bool(b1 & 0x40));
acc += select(0.0f, yl[15], bool(b1 & 0x80));
return qb_curr->d * (2.0f * acc - sumy);
}
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
@@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
}
}
template<int nr0, typename args_t>
void kernel_mul_mv_q1_0_f32_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
const int nb = args.ne00/QK1_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
device const float * y = (device const float *) (src1 + offset1);
device const block_q1_0 * ax[nr0];
for (int row = 0; row < nr0; ++row) {
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
}
float yl[16];
float sumf[nr0] = {0.f};
const short ix = (tiisg/8);
const short il = (tiisg%8)*16;
device const float * yb = y + ix*QK1_0 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
float sumy = 0.f;
FOR_UNROLL (short i = 0; i < 16; i++) {
yl[i] = yb[i];
sumy += yb[i];
}
FOR_UNROLL (short row = 0; row < nr0; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
}
yb += QK1_0 * (N_SIMDWIDTH/8);
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < args.ne01) {
dst_f32[first_row + row] = tot;
}
}
}
[[host_name("kernel_mul_mv_q1_0_f32")]]
kernel void kernel_mul_mv_q1_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_0_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
@@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
#endif
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
@@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
@@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
@@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
@@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
@@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -9893,6 +10079,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
@@ -9916,6 +10103,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@@ -10070,6 +10258,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
+1
View File
@@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
ggml_backend_buffer_free(opt_ctx->buf_cpu);
ggml_free(opt_ctx->ctx_static);
ggml_free(opt_ctx->ctx_cpu);
ggml_free(opt_ctx->ctx_copy);
delete opt_ctx;
}
+4
View File
@@ -44,6 +44,10 @@ void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 512: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst);
} break;
case 576: {
GGML_ASSERT(V->ne[0] == 512);
ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
+23 -128
View File
@@ -67,6 +67,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@@ -124,6 +130,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
@@ -131,134 +143,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
return 0;
}
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
return 0;
}
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
return 0;
}
static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
if(fast_fp16_available(cc))
return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
@@ -1293,6 +1177,16 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
return;
}
// ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV).
// For DKQ == 576, DV == 512 only GQA-optimized variants are implemented.
if constexpr (DKQ == DV) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
}
if constexpr (DV <= 256) {
@@ -1347,5 +1241,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
extern DECL_FATTN_TILE_CASE(576, 512);
+7
View File
@@ -664,4 +664,11 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16)
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0)
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1)
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0)
#endif // GGML_SYCL_FATTN_VEC_HPP
+3 -1
View File
@@ -34,6 +34,7 @@
FATTN_VEC_CASE( 64, type_K, type_V) \
FATTN_VEC_CASE(128, type_K, type_V) \
FATTN_VEC_CASE(256, type_K, type_V) \
FATTN_VEC_CASE(512, type_K, type_V) \
static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_tensor * Q = dst->src[0];
@@ -141,6 +142,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
case 128:
case 112:
case 256:
case 512:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
@@ -185,7 +187,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
}
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// Todo: Use the XMX kernel if possible:
+16 -5
View File
@@ -411,11 +411,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS;
}
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
!g_ggml_sycl_disable_optimize) {
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
if (!g_ggml_sycl_disable_optimize) {
// set reorder extra buffer based on supported type
switch (tensor->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:{
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra);
break;
}
default:
break;
}
}
if (ggml_is_quantized(tensor->type)) {
@@ -0,0 +1,6 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.hpp"
DECL_FATTN_TILE_CASE(512, 512);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_F16);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_F16);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
@@ -5,3 +5,4 @@
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
@@ -6,8 +6,8 @@
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#if defined(A_TYPEV4)
layout (binding = 0) readonly buffer AV4 {A_TYPEV4 data_a_v4[];};
#endif
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
@@ -17,11 +17,11 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#ifdef B_TYPEV2
layout (binding = 1) readonly buffer BV2 {B_TYPEV2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#ifdef B_TYPEV4
layout (binding = 1) readonly buffer BV4 {B_TYPEV4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
@@ -41,7 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
const FLOAT_TYPEV2 dm = vec2(data_a[ib0 + i].dm);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -11,8 +11,8 @@ FLOAT_TYPE get_dm(uint ib) {
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
FLOAT_TYPEV2 get_dm(uint ib) {
return FLOAT_TYPEV2(data_a_packed32[ib].dm);
}
#endif
@@ -23,9 +23,9 @@ FLOAT_TYPE get_dm(uint ib) {
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
FLOAT_TYPEV2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm);
}
#endif
@@ -304,7 +304,7 @@ vec2 get_dm_scale(uint ib, uint iqs) {
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
@@ -422,7 +422,7 @@ vec2 get_dm(uint ib, uint iqs) {
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
// the -1 cancels out the bias in iq1s_grid_gpu
return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
return FLOAT_TYPEV2(dl, dl * (delta - 1));
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
@@ -125,8 +125,8 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
#define SHMEM_STRIDE (BK / 2 + 1)
#endif
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
shared FLOAT_TYPEV2 buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPEV2 buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -258,17 +258,17 @@ void main() {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
ACC_TYPEV2 sums[WMITER * TM * WNITER * TN/2];
#if defined(DATA_A_F32) || defined(DATA_A_F16)
FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
FLOAT_TYPE_VEC4 cache_b;
FLOAT_TYPEV4 cache_a[WMITER * TM];
FLOAT_TYPEV4 cache_b;
#else
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
FLOAT_TYPE_VEC2 cache_b;
FLOAT_TYPEV2 cache_a[WMITER * TM];
FLOAT_TYPEV2 cache_b;
#endif
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
sums[i] = ACC_TYPEV2(0.0f, 0.0f);
}
#endif
@@ -3,7 +3,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#if LOAD_VEC_A == 8
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
@@ -11,38 +11,38 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#elif LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
data_a[idx + 1]);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
data_a[idx + 1]);
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
#elif defined(DATA_A_Q4_0)
@@ -57,10 +57,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy);
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v0.zw);
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy);
buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -73,10 +73,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy);
buf_a[buf_idx + 1 ] = FLOAT_TYPEV2(v0.zw);
buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy);
buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -92,8 +92,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -112,10 +112,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xz);
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v1.xz);
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw);
buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -128,8 +128,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -147,8 +147,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -171,8 +171,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);
const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),
dl * (qs.y - hm.y));
buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x),
dl * (qs.y - hm.y));
#elif defined(DATA_A_Q4_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -206,8 +206,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -244,8 +244,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
const vec4 q = vec4(unpack8(qs | qh));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -267,7 +267,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303;
const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale;
buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y);
buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y);
#elif defined(DATA_A_IQ1_S)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -284,8 +284,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
[[unroll]] for (int k = 0; k < 4; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
}
#elif defined(DATA_A_IQ1_M)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
@@ -306,8 +306,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
[[unroll]] for (int k = 0; k < 4; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
}
#elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
@@ -332,14 +332,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -358,14 +358,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -386,14 +386,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -414,10 +414,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint grid = iq3xxs_grid[qs];
const vec4 v = db * vec4(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
(sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
(sign & 8) != 0 ? -v.w : v.w);
buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x,
(sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z,
(sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -436,10 +436,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
const vec4 v = db * vec4(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
(sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
(sign & 8) != 0 ? -v.w : v.w);
buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x,
(sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z,
(sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -456,8 +456,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = float(data_a[ib].d);
const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -468,10 +468,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
kvalues_iq4nl[vui >> 12]);
buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF],
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
@@ -483,10 +483,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
kvalues_mxfp4[vui2 & 0xF] * d);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
kvalues_mxfp4[vui2 >> 4] * d);
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
kvalues_mxfp4[vui2 & 0xF] * d);
buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
kvalues_mxfp4[vui2 >> 4] * d);
#endif
}
@@ -496,7 +496,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
@@ -505,9 +505,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
@@ -515,12 +515,12 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (idx_n < p.N && block + row * 2 < end_k) {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
@@ -531,7 +531,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
@@ -541,9 +541,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
@@ -553,14 +553,14 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
}
#endif
}
@@ -21,7 +21,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm);
}
#endif
}
@@ -72,7 +72,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm);
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
}
#endif
@@ -203,7 +203,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
}
}
@@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
}
}
@@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
buf_a[buf_ib].dm = FLOAT_TYPEV2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
}
}
@@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint is = iqs_k / 4;
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales));
}
}
@@ -426,7 +426,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
buf_b[buf_ib].ds = FLOAT_TYPEV2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
@@ -436,7 +436,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
buf_b[buf_ib].ds = FLOAT_TYPEV2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
@@ -8,7 +8,7 @@ struct block_a_cache {
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE_VEC2 dm;
FLOAT_TYPEV2 dm;
};
#elif defined(DATA_A_Q5_0)
#define QUANT_R_MMQ 2
@@ -22,7 +22,7 @@ struct block_a_cache {
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE_VEC2 dm;
FLOAT_TYPEV2 dm;
};
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
@@ -43,36 +43,36 @@ struct block_a_cache {
struct block_a_cache {
uint32_t qs[2];
u8vec2 scales;
FLOAT_TYPE_VEC2 dm;
FLOAT_TYPEV2 dm;
};
#elif defined(DATA_A_Q3_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 d_scales;
FLOAT_TYPEV2 d_scales;
};
#elif defined(DATA_A_Q4_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 dm;
FLOAT_TYPEV2 dm;
};
#elif defined(DATA_A_Q5_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 dm;
FLOAT_TYPEV2 dm;
};
#elif defined(DATA_A_Q6_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 d_scales;
FLOAT_TYPEV2 d_scales;
};
#endif
struct block_b_cache
{
int32_t qs[8];
FLOAT_TYPE_VEC2 ds;
FLOAT_TYPEV2 ds;
};
@@ -446,8 +446,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
base_dict["FLOAT16"] = "1";
}
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPEV2"] = f16acc ? "f16vec2" : "vec2";
if (f16acc) {
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
@@ -514,10 +514,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
const std::map<std::string, std::string> float_type_dict_f16 = {
{"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
{"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
{"FLOAT_TYPEV2", FLOAT_TYPE(2, "f16")},
{"FLOAT_TYPEV4", FLOAT_TYPE(4, "f16")},
{"FLOAT_TYPEV8", FLOAT_TYPE(8, "f16")},
};
// Shaders with f16 B_TYPE
@@ -536,9 +536,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
const std::map<std::string, std::string> float_type_dict_bf16 = {
{"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
{"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
{"FLOAT_TYPEV2", FLOAT_TYPE(2, "bf16")},
{"FLOAT_TYPEV4", FLOAT_TYPE(4, "bf16")},
};
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
@@ -569,10 +569,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
const std::map<std::string, std::string> float_type_dict = {
{"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
{"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
{"FLOAT_TYPEV2", FLOAT_TYPE(2, tname)},
{"FLOAT_TYPEV4", FLOAT_TYPE(4, tname)},
{"FLOAT_TYPEV8", FLOAT_TYPE(8, tname)},
};
// don't generate f32 variants for coopmat2
@@ -676,36 +676,36 @@ void process_shaders() {
}
}
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}};
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif
@@ -726,9 +726,9 @@ void process_shaders() {
string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}});
// Norms
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+28 -13
View File
@@ -4033,8 +4033,14 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
static ggml_backend_webgpu_reg_context ctx;
static ggml_backend_reg reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_webgpu_reg_i,
/* .context = */ &ctx,
};
ctx.name = GGML_WEBGPU_NAME;
ctx.device_count = 1;
ctx.device_count = 0;
wgpu::InstanceDescriptor instance_descriptor{};
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
@@ -4053,19 +4059,28 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
ctx.webgpu_global_ctx->instance = std::move(inst);
#ifdef __EMSCRIPTEN__
if (ctx.webgpu_global_ctx->instance == nullptr) {
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
return nullptr;
}
#endif
GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
wgpu::Adapter adapter;
if (ctx.webgpu_global_ctx->instance != nullptr) {
wgpu::RequestAdapterOptions options = {};
// probe for adapter support
ctx.webgpu_global_ctx->instance.WaitAny(
ctx.webgpu_global_ctx->instance.RequestAdapter(
&options, wgpu::CallbackMode::AllowSpontaneous,
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
adapter = std::move(_adapter);
}),
UINT64_MAX);
}
if (adapter != nullptr) {
ctx.device_count = 1;
}
static ggml_backend_reg reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_webgpu_reg_i,
/* .context = */ &ctx,
};
return &reg;
}
Binary file not shown.
+111
View File
@@ -0,0 +1,111 @@
ied 4 ½ months
__ggml_vocab_test__
Äpfel
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__
=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
+46
View File
@@ -0,0 +1,46 @@
1178 236743 236812 47041 3794
239122 22744 535
236743
138
139
255968
107
108
109
255968 107
9259 1902
26352 1902
9259 4109
26352 4109
26352 4109 236888
9259 236764 1902 236888
26352 236764 1902 236888
672 563 236743 478 397 404 391 236761 12362
236765 236771 236812 236828 236743 236832 11372 12065 31806 3405 9360
1337 12515 1333 4632 165543 3830
234889 63031 219876 66212 239077 237907 144494
242015 568 7382 236768 236743 247717 237243 248989 238178 568 43819 111730 150567 236768 113452 568 8960 64334 600 815 1061 1852 8369 236768
9259
26352
138 9259
139 9259
140 9259
140 9259 107 140 9259
568
107 578
236789 6933
9259 236764 570 236789 712 236888 2088 659 611 170124 2360 62133 237075 17641 11700 236770 236800 236770 236812 236770 236810 236770 237471 238352
123947
236800
236800 236800
236800 236800 236800
236800 236800 236800 236800
236800 236800 236800 236800 236800
236800 236800 236800 236800 236800 236800
236800 236800 236800 236800 236800 236800 236800
236800 236800 236800 236800 236800 236800 236800 236800
236800 236800 236800 236800 236800 236800 236800 236800 236800
236780 29719 33154
2243 2206
107 236743 108 236743 109 236743 255968 236743 255969 236743 255968 107 138 107 139 107 140 107 141 107 242015 568 7382 236768 236743 247717 237243 248989 238178 568 43819 111730 150567 236768 113452 236743 478 397 404 391 478 397 404 391 236743 236800 236743 236800 236800 236743 236800 236800 236800 236743 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236800 236800 236743 236800 236761 236800 236743 236800 856 236800 236743 236800 1390 236800 90986 92814 63031 219876 66212 241702 2360 62133 237075 17641 11700 236770 236800 236770 236812 236770 236810 236770 237471 238352 80448 120697 210119 1333 4632 165543 3830 9451 159561 2629 2629 2717 84491 19938 123947 38950 10371 564 236789 560 1010 756 151812 668 236789 236751 993 236764 756 1357 611 2889 236881 756 236792 711 2889 564 236789 859 1386 625 236764 756 236796 611 1133 1070 11115 236881 1191 236789 32541 496 236789 95635
+16 -16
View File
@@ -558,20 +558,20 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
// example: https://github.com/ggml-org/llama.cpp/pull/17548
//
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
{LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
{LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
@@ -708,9 +708,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
// altup / laurel (gemma 3n)
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+2 -2
View File
@@ -2942,7 +2942,7 @@ llama_context * llama_init_from_model(
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
}
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) {
const uint32_t blck_size = ggml_blck_size(params.type_k);
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
@@ -2953,7 +2953,7 @@ llama_context * llama_init_from_model(
}
}
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) {
const uint32_t blck_size = ggml_blck_size(params.type_v);
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
+10 -9
View File
@@ -4211,13 +4211,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0);
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
@@ -4276,9 +4277,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
if (n_embd_per_layer > 0) {
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0);
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0);
}
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+1 -1
View File
@@ -534,9 +534,9 @@ struct llama_model {
struct ggml_tensor * conv1d_b = nullptr;
// gemma3n altup
struct ggml_tensor * tok_embd_per_layer = nullptr;
struct ggml_tensor * altup_proj = nullptr;
struct ggml_tensor * altup_unembd_proj = nullptr;
struct ggml_tensor * per_layer_tok_embd = nullptr;
struct ggml_tensor * per_layer_model_proj = nullptr;
struct ggml_tensor * per_layer_proj_norm = nullptr;
+40 -3
View File
@@ -659,8 +659,17 @@ struct llm_tokenizer_bpe_session {
if (token == LLAMA_TOKEN_NULL) {
for (auto j = str.begin(); j != str.end(); ++j) {
std::string byte_str(1, *j);
auto token_multibyte = vocab.text_to_token(byte_str);
llama_token token_multibyte = LLAMA_TOKEN_NULL;
if (tokenizer.byte_encode) {
std::string byte_str(1, *j);
token_multibyte = vocab.text_to_token(byte_str);
} else {
// For non-byte-encoded BPE (e.g. gemma-4), byte tokens use <0xXX> format
static const char * hex = "0123456789ABCDEF";
const uint8_t ch = (uint8_t)*j;
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
token_multibyte = vocab.text_to_token(buf);
}
if (token_multibyte != LLAMA_TOKEN_NULL) {
output.push_back(token_multibyte);
}
@@ -2558,7 +2567,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|| t.first == "[EOS]" // Kimi-K2
|| t.first == "<|end_of_text|>"
|| t.first == "<end_of_utterance>" // smoldocling
|| t.first == "<turn|>" // gemma4
|| t.first == "<eos>" // gemma4
|| t.first == "<turn|>" // gemma4
|| t.first == "<|tool_response>" // gemma4
|| t.first == "<end▁of▁sentence>" // deepseek-ocr
) {
@@ -2645,6 +2655,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
}
}
// workaround for gemma4 and paddleocr: do not include </s> as an eog token
{
bool has_tool_response = false;
bool has_s = false;
llama_token s_id = LLAMA_TOKEN_NULL;
for (auto tid : special_eog_ids) {
const auto & text = id_to_token[tid].text;
if (text == "<|tool_response>") {
has_tool_response = true;
} else if (text == "</s>") {
has_s = true;
s_id = tid;
}
}
if (has_tool_response && has_s) {
special_eog_ids.erase(s_id);
auto & attr = id_to_token[s_id].attr;
attr = LLAMA_TOKEN_ATTR_NORMAL;
LLAMA_LOG_WARN("%s: special_eog_ids contains '<|tool_response>', removing '</s>' token from EOG list\n", __func__);
}
}
}
// build special tokens cache
+36 -32
View File
@@ -1,5 +1,12 @@
#include "models.h"
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int) x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params),
model(model),
@@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_iswa();
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
ggml_tensor * inp_per_layer = build_inp_per_layer();
ggml_build_forward_expand(gf, inp_per_layer);
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
// inpL now has only 1 altup, project it to the rest of the altups
// these "added" altups will be concat to the last dim of inpL
@@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
cb(inpL, "inp_stacked", -1);
}
// inpL now has shape: [n_embd, n_tokens, n_altup]
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
// inpL now has shape: [n_embd, n_tokens, n_altup]
for (int il = 0; il < n_layer; ++il) {
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
@@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
// predicted value will go through self-attention and laurel
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
cur = active_prediction;
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens]
cur = active_prediction;
cb(cur, "active_prediction", il);
// norm
@@ -151,12 +160,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
ggml_tensor * first_prediction; // [n_embd, n_tokens]
{
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_gated", il);
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_scaled", il);
@@ -167,7 +177,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
}
// equivalent to python code: corrected_predictions[1:] += first_prediction
{
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0);
ggml_tensor * slice_rest = ggml_view_3d(
ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
@@ -185,7 +195,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
// cur now has multiple altup(s), we want to merge them back to 1 altup
{
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens]
// do a view to skip the first slice (active altup)
ggml_tensor * alt_slice =
ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
@@ -197,9 +207,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
cb(altup_unembd, "altup_unembd", -1);
// equivalent to torch.mean(hidden_states, dim=0)
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens]
for (int i = 0; i < n_altup - 1; ++i) {
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i));
}
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
cb(cur, "unembd_merged", -1);
@@ -235,23 +245,16 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int) x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() {
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
ggml_tensor * inp_per_layer;
if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
ggml_set_input(inp->tokens);
res->t_inp_tokens = inp->tokens;
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
inp_per_layer = ggml_get_rows(ctx0, model.per_layer_tok_embd, inp->tokens);
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
cb(inp_per_layer, "inp_per_layer_selected", -1);
@@ -259,10 +262,10 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
} else {
// Vision embedding path: use padding token (ID=0) embedding
// TODO: verify if this is the correct behavior in transformers implementation
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer
// Extract and dequantize padding token embedding (row 0)
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
// Reshape to [n_embd_altup, n_layer, 1]
@@ -275,18 +278,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// output shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
-1); // [n_embd_altup, n_layer, n_tokens]
ggml_tensor * per_layer_proj;
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1);
cb(per_layer_proj, "per_layer_proj", -1);
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
cb(inp_per_layer, "inp_per_layer", -1);
@@ -337,7 +341,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso
// input cur shape: [n_embd, n_tokens, n_altup]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens]
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);
@@ -365,7 +369,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
cb(innovation, "innovation", il);
+37 -28
View File
@@ -1,5 +1,12 @@
#include "models.h"
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int) x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params),
model(model),
@@ -19,14 +26,17 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_iswa();
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
ggml_tensor * inp_per_layer = nullptr;
if (model.tok_embd_per_layer) {
inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
}
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inp_per_layer = nullptr;
if (model.per_layer_tok_embd) {
inp_per_layer = build_inp_per_layer();
ggml_build_forward_expand(gf, inp_per_layer);
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
}
for (int il = 0; il < n_layer; ++il) {
const int64_t n_embd_head = hparams.n_embd_head_k(il);
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il));
@@ -196,7 +206,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens]
cur = ggml_gelu(ctx0, cur);
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_per_layer, n_tokens]
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
// TODO @ngxson : improve this
if (il == n_layer - 1 && inp_out_ids) {
@@ -248,34 +259,30 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
ggml_build_forward_expand(gf, cur);
}
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
ggml_tensor * llm_build_gemma4_iswa::view_2d_slice(ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int) x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_per_layer, n_layer, n_tokens]
ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() {
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
ggml_tensor * inp_per_layer;
if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
ggml_set_input(inp->tokens);
res->t_inp_tokens = inp->tokens;
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens);
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
inp_per_layer = ggml_scale (ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
cb(inp_per_layer, "inp_per_layer_selected", -1);
res->add_input(std::move(inp));
} else {
// Vision embedding path: use padding token (ID=0) embedding
// TODO: verify if this is the correct behavior in transformers implementation
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_per_layer * n_layer
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer
// Extract and dequantize padding token embedding (row 0)
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
// Reshape to [n_embd_per_layer, n_layer, 1]
@@ -287,21 +294,23 @@ ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// inputs_embeds shape: [n_embd, n_tokens]
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from get_per_layer_inputs)
// inp_batch shape: [n_embd, n_tokens]
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer)
// output shape: [n_embd_per_layer, n_tokens, n_layer]
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS,
-1); // [n_embd_per_layer, n_layer, n_tokens]
// note: this matrix multiplication will be performed in the input layer (i.e. on the CPU)
ggml_tensor * per_layer_proj;
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1);
cb(per_layer_proj, "per_layer_proj", -1);
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
cb(inp_per_layer, "inp_per_layer", -1);
+9 -6
View File
@@ -256,9 +256,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params);
ggml_tensor * calc_magnitude(ggml_tensor * x);
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
ggml_tensor * get_per_layer_inputs();
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
ggml_tensor * build_inp_per_layer();
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
ggml_tensor * gaussian_topk(ggml_tensor * x);
ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il);
ggml_tensor * altup_predict(ggml_tensor * cur, int il);
@@ -272,9 +274,10 @@ struct llm_build_gemma4_iswa : public llm_graph_context {
const int64_t n_embd_per_layer;
llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params);
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
ggml_tensor * get_per_layer_inputs();
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
ggml_tensor * build_inp_per_layer();
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
};
struct llm_build_gemma_embedding : public llm_graph_context {
+1
View File
@@ -124,6 +124,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${PROJE
llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-coder.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-llm.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-falcon.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gemma-4 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gemma-4.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-2.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf)
+2
View File
@@ -7251,6 +7251,7 @@ static const ggml_type all_types[] = {
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q1_0,
GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
@@ -7275,6 +7276,7 @@ static const ggml_type other_types[] = {
GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q1_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
+1 -1
View File
@@ -3454,7 +3454,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
},
"replaceAll": {
"type": "boolean",
"description": "Whether to replace all occurences."
"description": "Whether to replace all occurrences."
}
},
"required": ["oldString", "newString"]
+53
View File
@@ -447,6 +447,18 @@ static void test_expressions(testing & t) {
"hello world"
);
test_template(t, "string repetition",
"{{ 'ab' * 3 }}",
json::object(),
"ababab"
);
test_template(t, "reversed string repetition",
"{{ 3 * 'ab' }}",
json::object(),
"ababab"
);
test_template(t, "ternary",
"{{ 'yes' if cond else 'no' }}",
{{"cond", true}},
@@ -693,6 +705,33 @@ static void test_filters(testing & t) {
"\"\\u2713\""
);
test_template(t, "tojson ensure_ascii=true nested object",
"{{ data|tojson(ensure_ascii=true) }}",
{{"data", {
{"text", "\u2713"},
{"items", json::array({"é", {{"snowman", ""}}})}
}}},
"{\"text\": \"\\u2713\", \"items\": [\"\\u00e9\", {\"snowman\": \"\\u2603\"}]}"
);
test_template(t, "tojson ensure_ascii=true indent=2",
"{{ data|tojson(ensure_ascii=true, indent=2) }}",
{{"data", {
{"text", "\u2713"},
{"nested", {{"accent", "é"}}}
}}},
"{\n \"text\": \"\\u2713\",\n \"nested\": {\n \"accent\": \"\\u00e9\"\n }\n}"
);
test_template(t, "tojson ensure_ascii=true preserves existing escapes",
"{{ data|tojson(ensure_ascii=true) }}",
{{"data", {
{"emoji", "😀"},
{"line", "a\nb"}
}}},
"{\"emoji\": \"\\ud83d\\ude00\", \"line\": \"a\\nb\"}"
);
test_template(t, "tojson sort_keys=true",
"{{ data|tojson(sort_keys=true) }}",
{{"data", {{"b", 2}, {"a", 1}}}},
@@ -771,6 +810,12 @@ static void test_filters(testing & t) {
"hello"
);
test_template(t, "int filter on integer is identity",
"{{ value|int }}",
{{"value", 7}},
"7"
);
test_template(t, "none to string",
"{{ x|string }}",
{{"x", nullptr}},
@@ -2458,4 +2503,12 @@ static void test_fuzzing(testing & t) {
t.assert_true("builtin " + type_name + "." + fn_name + " #" + std::to_string(i), fuzz_test_template(tmpl, vars));
}
});
t.test("tojson ensure_ascii=true with invalid utf-8", [&](testing & t) {
t.assert_true("invalid utf-8 does not crash",
fuzz_test_template(
"{{ data|tojson(ensure_ascii=true) }}",
{{"data", std::string("hello\xfe\xffworld")}}
));
});
}
File diff suppressed because one or more lines are too long
+1 -1
View File
@@ -18,7 +18,7 @@
<div style="display: contents">
<script>
{
__sveltekit_1ppa22i = {
__sveltekit_6n4hpv = {
base: new URL('.', location).pathname.slice(0, -1)
};
+3
View File
@@ -3033,6 +3033,8 @@ server_context_meta server_context::get_meta() const {
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
/* logit_bias_eog */ impl->params_base.sampling.logit_bias_eog,
/* model_vocab_type */ llama_vocab_type(impl->vocab),
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
@@ -3117,6 +3119,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
ctx_server.vocab,
params,
meta->slot_n_ctx,
meta->logit_bias_eog,
data);
task.id_slot = json_value(data, "id_slot", -1);
+3
View File
@@ -39,6 +39,9 @@ struct server_context_meta {
llama_token fim_rep_token;
llama_token fim_sep_token;
// sampling
std::vector<llama_logit_bias> logit_bias_eog;
// model meta
enum llama_vocab_type model_vocab_type;
int32_t model_vocab_n_tokens;
+4 -1
View File
@@ -239,6 +239,7 @@ task_params server_task::params_from_json_cmpl(
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data) {
task_params params;
@@ -383,6 +384,8 @@ task_params server_task::params_from_json_cmpl(
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = defaults.sampling.grammar;
std::string grammar_str = json_value(data, "grammar", std::string());
if (!grammar_str.empty()) {
// grammar_type key is set by the server when converting chat template grammars
@@ -562,7 +565,7 @@ task_params server_task::params_from_json_cmpl(
if (params.sampling.ignore_eos) {
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
logit_bias_eog.begin(), logit_bias_eog.end());
}
}
+1
View File
@@ -209,6 +209,7 @@ struct server_task {
const llama_vocab * vocab,
const common_params & params_base,
const int n_ctx_slot,
const std::vector<llama_logit_bias> & logit_bias_eog,
const json & data);
// utility function
+1 -1
View File
@@ -135,7 +135,7 @@ def test_completion_stream_with_openai_library_stops():
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
prompt="System: You are helpful assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
stop=["User:\n", "Assistant:\n"],
max_tokens=200,
stream=True,
@@ -0,0 +1,43 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_ignore_eos_populates_logit_bias():
"""ignore_eos=true must add EOG logit biases to generation_settings."""
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "Once upon a time",
"ignore_eos": True,
"temperature": 0.0,
})
assert res.status_code == 200
# EOG token biases must be present with -inf bias
logit_bias = res.body["generation_settings"]["logit_bias"]
assert len(logit_bias) > 0
for entry in logit_bias:
assert entry["bias"] is None # null in JSON represents -inf
def test_ignore_eos_false_no_logit_bias():
"""ignore_eos=false (default) must NOT add EOG logit biases."""
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "Once upon a time",
"ignore_eos": False,
"temperature": 0.0,
})
assert res.status_code == 200
logit_bias = res.body["generation_settings"]["logit_bias"]
assert len(logit_bias) == 0
@@ -62,10 +62,14 @@
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
);
let lastSyncedConversationModel: string | null = null;
$effect(() => {
if (conversationModel) {
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
lastSyncedConversationModel = conversationModel;
modelsStore.selectModelByName(conversationModel);
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
lastSyncedConversationModel = null;
// auto-select the first loaded model only when nothing is selected yet
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) modelsStore.selectModelById(first.id);
@@ -291,14 +291,19 @@
title: SETTINGS_SECTION_TITLES.DEVELOPER,
icon: Code,
fields: [
{
key: SETTINGS_KEYS.PRE_ENCODE_CONVERSATION,
label: 'Pre-fill KV cache after response',
type: SettingsFieldType.CHECKBOX
},
{
key: SETTINGS_KEYS.DISABLE_REASONING_PARSING,
label: 'Disable reasoning content parsing',
label: 'Disable server-side thinking extraction',
type: SettingsFieldType.CHECKBOX
},
{
key: SETTINGS_KEYS.EXCLUDE_REASONING_FROM_CONTEXT,
label: 'Exclude reasoning from context',
label: 'Strip thinking from message history',
type: SettingsFieldType.CHECKBOX
},
{
@@ -56,6 +56,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
dry_penalty_last_n: undefined,
max_tokens: undefined,
custom: '', // custom json-stringified object
preEncodeConversation: false,
// experimental features
pyInterpreterEnabled: false,
enableContinueGeneration: false
@@ -106,9 +107,9 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.',
showThoughtInProgress: 'Expand thought process by default when generating messages.',
disableReasoningParsing:
'Send reasoning_format=none to prevent server-side extraction of reasoning tokens into separate field',
'Send reasoning_format=none so the server returns thinking tokens inline instead of extracting them into a separate field.',
excludeReasoningFromContext:
'Strip reasoning content from previous messages before sending to the model. When unchecked, reasoning is sent back via the reasoning_content field so the model can see its own chain-of-thought across turns.',
'Strip thinking from previous messages before sending. When off, thinking is sent back via the reasoning_content field so the model sees its own chain-of-thought across turns.',
showRawOutputSwitch:
'Show toggle button to display messages as plain text instead of Markdown-formatted content',
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
@@ -143,6 +144,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
'Automatically expand tool call details while executing and keep them expanded after completion.',
pyInterpreterEnabled:
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
preEncodeConversation:
'After each response, re-submit the conversation to pre-fill the server KV cache. Makes the next turn faster since the prompt is already encoded while you read the response.',
enableContinueGeneration:
'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.'
};
@@ -52,6 +52,8 @@ export const SETTINGS_KEYS = {
ALWAYS_SHOW_AGENTIC_TURNS: 'alwaysShowAgenticTurns',
AGENTIC_MAX_TOOL_PREVIEW_LINES: 'agenticMaxToolPreviewLines',
SHOW_TOOL_CALL_IN_PROGRESS: 'showToolCallInProgress',
// Performance
PRE_ENCODE_CONVERSATION: 'preEncodeConversation',
// Developer
DISABLE_REASONING_PARSING: 'disableReasoningParsing',
EXCLUDE_REASONING_FROM_CONTEXT: 'excludeReasoningFromContext',
@@ -4,7 +4,8 @@ import { isAbortError } from '$lib/utils/abort';
import {
ATTACHMENT_LABEL_PDF_FILE,
ATTACHMENT_LABEL_MCP_PROMPT,
ATTACHMENT_LABEL_MCP_RESOURCE
ATTACHMENT_LABEL_MCP_RESOURCE,
LEGACY_AGENTIC_REGEX
} from '$lib/constants';
import {
AttachmentType,
@@ -279,6 +280,107 @@ export class ChatService {
}
}
/**
* Checks whether all server slots are currently idle (not processing any requests).
* Queries the /slots endpoint (requires --slots flag on the server).
* Returns true if all slots are idle, false if any is processing.
* If the endpoint is unavailable or errors out, returns true (best-effort fallback).
*
* @param signal - Optional AbortSignal to cancel the request if needed
* @param model - Optional model name to check slots for (required in ROUTER mode)
* @returns {Promise<boolean>} Promise that resolves to true if all slots are idle, false if any is processing
*/
static async areAllSlotsIdle(model?: string | null, signal?: AbortSignal): Promise<boolean> {
try {
const url = model ? `./slots?model=${encodeURIComponent(model)}` : './slots';
const res = await fetch(url, { signal });
if (!res.ok) return true;
const slots: { is_processing: boolean }[] = await res.json();
return slots.every((s) => !s.is_processing);
} catch {
return true;
}
}
/**
* Sends a fire-and-forget request to pre-encode the conversation in the server's KV cache.
* After a response completes, this re-submits the full conversation
* using n_predict=0 and stream=false so the server processes the prompt without generating tokens.
* This warms the cache for the next turn, making it faster.
*
* When excludeReasoningFromContext is true, reasoning content is stripped from the messages
* to match what sendMessage would send on the next turn (avoiding cache misses).
* When false, reasoning_content is preserved so the cached prompt matches the next request.
*
* @param messages - The full conversation including the latest assistant response
* @param model - Optional model name (required in ROUTER mode)
* @param excludeReasoning - Whether to strip reasoning content (should match excludeReasoningFromContext setting)
* @param signal - Optional AbortSignal to cancel the pre-encode request
*/
static async preEncode(
messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[],
model?: string | null,
excludeReasoning?: boolean,
signal?: AbortSignal
): Promise<void> {
const normalizedMessages: ApiChatMessageData[] = messages
.map((msg) => {
if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
return ChatService.convertDbMessageToApiChatMessageData(
msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }
);
}
return msg as ApiChatMessageData;
})
.filter((msg) => {
if (msg.role === MessageRole.SYSTEM) {
const content = typeof msg.content === 'string' ? msg.content : '';
return content.trim().length > 0;
}
return true;
});
const requestBody: Record<string, unknown> = {
messages: normalizedMessages.map((msg: ApiChatMessageData) => {
const mapped: Record<string, unknown> = {
role: msg.role,
content: excludeReasoning ? ChatService.stripReasoningContent(msg.content) : msg.content,
tool_calls: msg.tool_calls,
tool_call_id: msg.tool_call_id
};
if (!excludeReasoning && msg.reasoning_content) {
mapped.reasoning_content = msg.reasoning_content;
}
return mapped;
}),
stream: false,
n_predict: 0
};
if (model) {
requestBody.model = model;
}
try {
await fetch(`./v1/chat/completions`, {
method: 'POST',
headers: getJsonHeaders(),
body: JSON.stringify(requestBody),
signal
});
} catch (error) {
if (!isAbortError(error)) {
console.warn('[ChatService] Pre-encode request failed:', error);
}
}
}
/**
*
*
@@ -799,6 +901,28 @@ export class ChatService {
*
*/
/**
* Strips legacy inline reasoning content tags from message content.
* Handles both plain string content and multipart content arrays.
*/
private static stripReasoningContent(
content: string | ApiChatMessageContentPart[]
): string | ApiChatMessageContentPart[] {
const stripFromString = (text: string): string =>
text.replace(LEGACY_AGENTIC_REGEX.REASONING_BLOCK, '').trim();
if (typeof content === 'string') {
return stripFromString(content);
}
return content.map((part) => {
if (part.type === ContentPartType.TEXT && part.text) {
return { ...part, text: stripFromString(part.text) };
}
return part;
});
}
/**
* Parses error response and creates appropriate error with context information
* @param response - HTTP response object
@@ -88,6 +88,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
},
{ key: 'max_tokens', serverKey: 'max_tokens', type: SyncableParameterType.NUMBER, canSync: true },
{ key: 'samplers', serverKey: 'samplers', type: SyncableParameterType.STRING, canSync: true },
{
key: 'backend_sampling',
serverKey: 'backend_sampling',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'pasteLongTextToFileLen',
serverKey: 'pasteLongTextToFileLen',
@@ -58,6 +58,7 @@ class ChatStore {
chatLoadingStates = new SvelteMap<string, boolean>();
chatStreamingStates = new SvelteMap<string, { response: string; messageId: string }>();
private abortControllers = new SvelteMap<string, AbortController>();
private preEncodeAbortController: AbortController | null = null;
private processingStates = new SvelteMap<string, ApiProcessingState | null>();
private conversationStateTimestamps = new SvelteMap<string, ConversationStateEntry>();
private activeConversationId = $state<string | null>(null);
@@ -462,6 +463,9 @@ class ChatStore {
const activeConv = conversationsStore.activeConversation;
if (activeConv && this.isChatLoadingInternal(activeConv.id)) return;
// Cancel any in-flight pre-encode request
this.cancelPreEncode();
// Consume MCP resource attachments - converts them to extras and clears the live store
const resourceExtras = mcpStore.consumeResourceAttachmentsAsExtras();
const allExtras = resourceExtras.length > 0 ? [...(extras || []), ...resourceExtras] : extras;
@@ -724,6 +728,16 @@ class ChatStore {
if (onComplete) onComplete(streamedContent);
if (isRouterMode()) modelsStore.fetchRouterModels().catch(console.error);
// Pre-encode conversation in KV cache for faster next turn
if (config().preEncodeConversation) {
this.triggerPreEncode(
allMessages,
assistantMessage,
streamedContent,
effectiveModel,
!!config().excludeReasoningFromContext
);
}
},
onError: (error: Error) => {
this.setStreamingActive(false);
@@ -911,6 +925,7 @@ class ChatStore {
async regenerateMessage(messageId: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isChatLoadingInternal(activeConv.id)) return;
this.cancelPreEncode();
const result = this.getMessageByIdWithRole(messageId, MessageRole.ASSISTANT);
if (!result) return;
const { index: messageIndex } = result;
@@ -940,6 +955,7 @@ class ChatStore {
async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise<void> {
const activeConv = conversationsStore.activeConversation;
if (!activeConv || this.isChatLoadingInternal(activeConv.id)) return;
this.cancelPreEncode();
try {
const idx = conversationsStore.findMessageIndex(messageId);
if (idx === -1) return;
@@ -1610,13 +1626,48 @@ class ChatStore {
if (currentConfig.samplers) apiOptions.samplers = currentConfig.samplers;
if (currentConfig.backend_sampling)
apiOptions.backend_sampling = currentConfig.backend_sampling;
apiOptions.backend_sampling = currentConfig.backend_sampling;
if (currentConfig.custom) apiOptions.custom = currentConfig.custom;
return apiOptions;
}
private cancelPreEncode(): void {
if (this.preEncodeAbortController) {
this.preEncodeAbortController.abort();
this.preEncodeAbortController = null;
}
}
private async triggerPreEncode(
allMessages: DatabaseMessage[],
assistantMessage: DatabaseMessage,
assistantContent: string,
model?: string | null,
excludeReasoning?: boolean
): Promise<void> {
this.cancelPreEncode();
this.preEncodeAbortController = new AbortController();
const signal = this.preEncodeAbortController.signal;
try {
const allIdle = await ChatService.areAllSlotsIdle(model, signal);
if (!allIdle || signal.aborted) return;
const messagesWithAssistant: DatabaseMessage[] = [
...allMessages,
{ ...assistantMessage, content: assistantContent }
];
await ChatService.preEncode(messagesWithAssistant, model, excludeReasoning, signal);
} catch (err) {
if (!isAbortError(err)) {
console.warn('[ChatStore] Pre-encode failed:', err);
}
}
}
}
export const chatStore = new ChatStore();
@@ -77,6 +77,11 @@
!modelsStore.isModelLoaded(modelsStore.selectedModelName)
) {
modelsStore.clearSelection();
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) {
await modelsStore.selectModelById(first.id);
}
}
// Handle URL params only if we have ?q= or ?model= or ?new_chat=true