mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 09:37:42 +02:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0ec191e1d7 | |||
| 243532e556 | |||
| 5e9c635463 | |||
| 9949ad08f6 | |||
| 3ee9da0e4f | |||
| 75511a8d7e | |||
| b54cb2e3d0 | |||
| 8a65a7a8ee | |||
| 8a132faaa0 | |||
| 4293919068 | |||
| d12cc3d1ca | |||
| 2dcb7f74ed | |||
| 660600081f | |||
| d9a12c82f0 | |||
| 4a05e0c566 | |||
| e9fd96283d | |||
| 3ba12fed0a | |||
| 5473949070 | |||
| dcdcbad42a | |||
| 5764d7c6a6 |
+5
-13
@@ -75,21 +75,13 @@ android:
|
||||
- examples/llama.android/**
|
||||
server/webui:
|
||||
- changed-files:
|
||||
- all:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/webui/**
|
||||
- tools/server/public/**
|
||||
- all-globs-to-all-files:
|
||||
- '!tools/server/webui/**'
|
||||
- '!tools/server/public/**'
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/webui/**
|
||||
- tools/server/public/**
|
||||
server:
|
||||
- changed-files:
|
||||
- all:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/**
|
||||
- all-globs-to-all-files:
|
||||
- '!tools/server/webui/**'
|
||||
- '!tools/server/public/**'
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/**
|
||||
|
||||
|
||||
|
||||
|
||||
+11
-1
@@ -591,6 +591,10 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
for (const auto & f : files) {
|
||||
if (gguf_filename_is_model(f.path) &&
|
||||
std::regex_search(f.path, pattern)) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.count > 1 && split.index != 1) {
|
||||
continue;
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
@@ -600,6 +604,10 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
if (tag.empty()) {
|
||||
for (const auto & f : files) {
|
||||
if (gguf_filename_is_model(f.path)) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.count > 1 && split.index != 1) {
|
||||
continue;
|
||||
}
|
||||
return f;
|
||||
}
|
||||
}
|
||||
@@ -618,6 +626,7 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) {
|
||||
}
|
||||
|
||||
struct hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
};
|
||||
@@ -663,6 +672,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
}
|
||||
}
|
||||
|
||||
plan.primary = primary;
|
||||
plan.model_files = get_split_files(all, primary);
|
||||
|
||||
if (opts.download_mmproj) {
|
||||
@@ -749,7 +759,7 @@ common_download_model_result common_download_model(const common_params_model
|
||||
for (const auto & f : hf.model_files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
result.model_path = hf.model_files[0].final_path;
|
||||
result.model_path = hf.primary.final_path;
|
||||
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
|
||||
@@ -251,6 +251,23 @@ value binary_expression::execute_impl(context & ctx) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Python-style string repetition
|
||||
// TODO: support array/tuple repetition (e.g., [1, 2] * 3 → [1, 2, 1, 2, 1, 2])
|
||||
if (op.value == "*" &&
|
||||
((is_val<value_string>(left_val) && is_val<value_int>(right_val)) ||
|
||||
(is_val<value_int>(left_val) && is_val<value_string>(right_val)))) {
|
||||
const auto & str = is_val<value_string>(left_val) ? left_val->as_string() : right_val->as_string();
|
||||
const int64_t repeat = is_val<value_int>(right_val) ? right_val->as_int() : left_val->as_int();
|
||||
auto res = mk_val<value_string>();
|
||||
if (repeat <= 0) {
|
||||
return res;
|
||||
}
|
||||
for (int64_t i = 0; i < repeat; ++i) {
|
||||
res->val_str = res->val_str.append(str);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// String membership
|
||||
if (is_val<value_string>(left_val) && is_val<value_string>(right_val)) {
|
||||
// case: "a" in "abc"
|
||||
|
||||
+90
-3
@@ -1,4 +1,5 @@
|
||||
#include "runtime.h"
|
||||
#include "unicode.h"
|
||||
#include "value.h"
|
||||
|
||||
// for converting from JSON to jinja values
|
||||
@@ -154,6 +155,83 @@ static value test_compare_fn(const func_args & args) {
|
||||
return mk_val<value_bool>(value_compare(args.get_pos(0), args.get_pos(1), op));
|
||||
}
|
||||
|
||||
static void append_codepoint_as_ascii_json_escape(std::string & out, uint32_t codepoint) {
|
||||
auto append_u16 = [&out](uint32_t value) {
|
||||
char buf[8];
|
||||
snprintf(buf, sizeof(buf), "\\u%04x", static_cast<unsigned int>(value));
|
||||
out += buf;
|
||||
};
|
||||
|
||||
if (codepoint <= 0xFFFF) {
|
||||
append_u16(codepoint);
|
||||
return;
|
||||
}
|
||||
|
||||
codepoint -= 0x10000;
|
||||
append_u16(0xD800 + ((codepoint >> 10) & 0x3FF));
|
||||
append_u16(0xDC00 + (codepoint & 0x3FF));
|
||||
}
|
||||
|
||||
static std::string json_ensure_ascii_preserving_format(const std::string & json_str) {
|
||||
std::string output;
|
||||
output.reserve(json_str.size());
|
||||
|
||||
bool in_string = false;
|
||||
bool escaped = false;
|
||||
|
||||
for (size_t pos = 0; pos < json_str.size();) {
|
||||
const char ch = json_str[pos];
|
||||
if (!in_string) {
|
||||
output.push_back(ch);
|
||||
if (ch == '"') {
|
||||
in_string = true;
|
||||
}
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (escaped) {
|
||||
output.push_back(ch);
|
||||
escaped = false;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '\\') {
|
||||
output.push_back(ch);
|
||||
escaped = true;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ch == '"') {
|
||||
output.push_back(ch);
|
||||
in_string = false;
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
const unsigned char uch = static_cast<unsigned char>(ch);
|
||||
if (uch < 0x80) {
|
||||
output.push_back(ch);
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto parsed = common_parse_utf8_codepoint(json_str, pos);
|
||||
if (parsed.status != utf8_parse_result::SUCCESS) {
|
||||
output += "\\ufffd";
|
||||
++pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
append_codepoint_as_ascii_json_escape(output, parsed.codepoint);
|
||||
pos += parsed.bytes_consumed;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
static value tojson(const func_args & args) {
|
||||
args.ensure_count(1, 5);
|
||||
value val_ascii = args.get_kwarg_or_pos("ensure_ascii", 1);
|
||||
@@ -169,16 +247,17 @@ static value tojson(const func_args & args) {
|
||||
if (is_val<value_int>(val_indent)) {
|
||||
indent = static_cast<int>(val_indent->as_int());
|
||||
}
|
||||
if (val_ascii->as_bool()) { // undefined == false
|
||||
throw not_implemented_exception("tojson ensure_ascii=true not implemented");
|
||||
}
|
||||
if (val_sort->as_bool()) { // undefined == false
|
||||
throw not_implemented_exception("tojson sort_keys=true not implemented");
|
||||
}
|
||||
const bool ensure_ascii = val_ascii->as_bool(); // undefined == false
|
||||
auto separators = (is_val<value_array>(val_separators) ? val_separators : mk_val<value_array>())->as_array();
|
||||
std::string item_sep = separators.size() > 0 ? separators[0]->as_string().str() : (indent < 0 ? ", " : ",");
|
||||
std::string key_sep = separators.size() > 1 ? separators[1]->as_string().str() : ": ";
|
||||
std::string json_str = value_to_json(args.get_pos(0), indent, item_sep, key_sep);
|
||||
if (ensure_ascii) {
|
||||
json_str = json_ensure_ascii_preserving_format(json_str);
|
||||
}
|
||||
return mk_val<value_string>(json_str);
|
||||
}
|
||||
|
||||
@@ -460,6 +539,10 @@ const func_builtins & value_int_t::get_builtins() const {
|
||||
int64_t val = args.get_pos(0)->as_int();
|
||||
return mk_val<value_int>(val < 0 ? -val : val);
|
||||
}},
|
||||
{"int", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_int>();
|
||||
return mk_val<value_int>(args.get_pos(0)->as_int());
|
||||
}},
|
||||
{"float", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_int>();
|
||||
double val = static_cast<double>(args.get_pos(0)->as_int());
|
||||
@@ -486,6 +569,10 @@ const func_builtins & value_float_t::get_builtins() const {
|
||||
int64_t val = static_cast<int64_t>(args.get_pos(0)->as_float());
|
||||
return mk_val<value_int>(val);
|
||||
}},
|
||||
{"float", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_float>();
|
||||
return mk_val<value_float>(args.get_pos(0)->as_float());
|
||||
}},
|
||||
{"safe", tojson},
|
||||
{"string", tojson},
|
||||
{"tojson", tojson},
|
||||
|
||||
@@ -1173,7 +1173,11 @@ struct ggml_cuda_graph {
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool warmup_complete = false;
|
||||
std::vector<ggml_tensor> nodes_copy;
|
||||
struct node_properties {
|
||||
ggml_tensor node;
|
||||
void * node_src_data_ptrs[GGML_MAX_SRC];
|
||||
};
|
||||
std::vector<node_properties> node_props;
|
||||
|
||||
bool is_enabled() const {
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
|
||||
@@ -2979,18 +2979,25 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
// Check if the graph size has changed
|
||||
if ((int)graph->nodes_copy.size() != cgraph->n_nodes) {
|
||||
if ((int)graph->node_props.size() != cgraph->n_nodes) {
|
||||
res = true;
|
||||
graph->nodes_copy.resize(cgraph->n_nodes);
|
||||
graph->node_props.resize(cgraph->n_nodes);
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (!res) {
|
||||
if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph::node_properties prop = {};
|
||||
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
|
||||
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
|
||||
for (int j = 0; j < GGML_MAX_SRC; ++j) {
|
||||
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
|
||||
}
|
||||
memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor));
|
||||
|
||||
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
|
||||
res = true;
|
||||
}
|
||||
graph->node_props[i] = prop;
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
@@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
@@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
|
||||
smem = 32*sizeof(float)*nr0;
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
|
||||
@@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
//
|
||||
// TODO: for optimal performance, become function of the device and work size
|
||||
|
||||
#define N_R0_Q1_0 8
|
||||
#define N_SG_Q1_0 2
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
|
||||
@@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
||||
op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_BF16 ||
|
||||
op->src[0]->type == GGML_TYPE_Q1_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
||||
|
||||
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * qs = xb->qs;
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
|
||||
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
||||
const uint8_t b0 = qs[byte_offset];
|
||||
const uint8_t b1 = qs[byte_offset + 1];
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
||||
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
||||
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
||||
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
||||
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
||||
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
||||
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
||||
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
||||
|
||||
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
||||
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
||||
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
||||
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
||||
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
||||
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
||||
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
||||
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
const int base = il * 4;
|
||||
const uint8_t byte = xb->qs[base / 8];
|
||||
const int s = base % 8;
|
||||
|
||||
float4 reg_f;
|
||||
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
||||
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
||||
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
||||
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
||||
|
||||
reg = (type4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
sum_abs += fabs(src[j]);
|
||||
}
|
||||
dst.d = sum_abs / QK1_0;
|
||||
|
||||
for (int j = 0; j < QK1_0 / 8; j++) {
|
||||
dst.qs[j] = 0;
|
||||
}
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
if (src[j] >= 0.0f) {
|
||||
dst.qs[j / 8] |= (1 << (j % 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
@@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
|
||||
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||
device const uint8_t * qs = qb_curr->qs + il / 8;
|
||||
const uint8_t b0 = qs[0];
|
||||
const uint8_t b1 = qs[1];
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
|
||||
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
|
||||
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
|
||||
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
|
||||
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
|
||||
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
|
||||
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
|
||||
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
|
||||
|
||||
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
|
||||
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
|
||||
acc += select(0.0f, yl[10], bool(b1 & 0x04));
|
||||
acc += select(0.0f, yl[11], bool(b1 & 0x08));
|
||||
acc += select(0.0f, yl[12], bool(b1 & 0x10));
|
||||
acc += select(0.0f, yl[13], bool(b1 & 0x20));
|
||||
acc += select(0.0f, yl[14], bool(b1 & 0x40));
|
||||
acc += select(0.0f, yl[15], bool(b1 & 0x80));
|
||||
|
||||
return qb_curr->d * (2.0f * acc - sumy);
|
||||
}
|
||||
|
||||
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||
@@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q1_0_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
const int nb = args.ne00/QK1_0;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
|
||||
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
device const block_q1_0 * ax[nr0];
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
|
||||
}
|
||||
|
||||
float yl[16];
|
||||
float sumf[nr0] = {0.f};
|
||||
|
||||
const short ix = (tiisg/8);
|
||||
const short il = (tiisg%8)*16;
|
||||
|
||||
device const float * yb = y + ix*QK1_0 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
|
||||
float sumy = 0.f;
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
yl[i] = yb[i];
|
||||
sumy += yb[i];
|
||||
}
|
||||
|
||||
FOR_UNROLL (short row = 0; row < nr0; row++) {
|
||||
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
|
||||
}
|
||||
|
||||
yb += QK1_0 * (N_SIMDWIDTH/8);
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_q1_0_f32")]]
|
||||
kernel void kernel_mul_mv_q1_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q4_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
@@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
|
||||
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
||||
@@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
|
||||
|
||||
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
@@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
@@ -9893,6 +10079,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
@@ -9916,6 +10103,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
@@ -10070,6 +10258,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
||||
|
||||
@@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
|
||||
ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
||||
ggml_free(opt_ctx->ctx_static);
|
||||
ggml_free(opt_ctx->ctx_cpu);
|
||||
ggml_free(opt_ctx->ctx_copy);
|
||||
delete opt_ctx;
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,10 @@ void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 512: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
|
||||
@@ -67,6 +67,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
@@ -124,6 +130,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
@@ -131,134 +143,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
||||
if(fast_fp16_available(cc))
|
||||
return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
|
||||
@@ -1293,6 +1177,16 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
// ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV).
|
||||
// For DKQ == 576, DV == 512 only GQA-optimized variants are implemented.
|
||||
if constexpr (DKQ == DV) {
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
@@ -1347,5 +1241,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
|
||||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
extern DECL_FATTN_TILE_CASE(256, 256);
|
||||
extern DECL_FATTN_TILE_CASE(512, 512);
|
||||
extern DECL_FATTN_TILE_CASE(576, 512);
|
||||
|
||||
|
||||
@@ -664,4 +664,11 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1)
|
||||
EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0)
|
||||
|
||||
#endif // GGML_SYCL_FATTN_VEC_HPP
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
FATTN_VEC_CASE(128, type_K, type_V) \
|
||||
FATTN_VEC_CASE(256, type_K, type_V) \
|
||||
FATTN_VEC_CASE(512, type_K, type_V) \
|
||||
|
||||
static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
@@ -141,6 +142,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
case 512:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
@@ -185,7 +187,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// Todo: Use the XMX kernel if possible:
|
||||
|
||||
|
||||
@@ -411,11 +411,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
|
||||
!g_ggml_sycl_disable_optimize) {
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||
|
||||
if (!g_ggml_sycl_disable_optimize) {
|
||||
// set reorder extra buffer based on supported type
|
||||
switch (tensor->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:{
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_is_quantized(tensor->type)) {
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.hpp"
|
||||
|
||||
DECL_FATTN_TILE_CASE(512, 512);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_F16);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_VEC4)
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
#if defined(A_TYPEV4)
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPEV4 data_a_v4[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
@@ -17,11 +17,11 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
#ifdef B_TYPE_VEC2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||
#ifdef B_TYPEV2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPEV2 data_b_v2[];};
|
||||
#endif
|
||||
#ifdef B_TYPE_VEC4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#ifdef B_TYPEV4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPEV4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
@@ -41,7 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
||||
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||
|
||||
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
|
||||
const FLOAT_TYPEV2 dm = vec2(data_a[ib0 + i].dm);
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
|
||||
|
||||
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
|
||||
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
|
||||
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
|
||||
const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm);
|
||||
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
|
||||
@@ -11,8 +11,8 @@ FLOAT_TYPE get_dm(uint ib) {
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
FLOAT_TYPEV2 get_dm(uint ib) {
|
||||
return FLOAT_TYPEV2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -23,9 +23,9 @@ FLOAT_TYPE get_dm(uint ib) {
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q2_K)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
FLOAT_TYPEV2 get_dm(uint ib) {
|
||||
const uint ib_k = ib / 8;
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -304,7 +304,7 @@ vec2 get_dm_scale(uint ib, uint iqs) {
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
|
||||
return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm);
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
@@ -422,7 +422,7 @@ vec2 get_dm(uint ib, uint iqs) {
|
||||
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
|
||||
// the -1 cancels out the bias in iq1s_grid_gpu
|
||||
return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
|
||||
return FLOAT_TYPEV2(dl, dl * (delta - 1));
|
||||
}
|
||||
|
||||
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
|
||||
|
||||
@@ -125,8 +125,8 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
|
||||
#define SHMEM_STRIDE (BK / 2 + 1)
|
||||
#endif
|
||||
|
||||
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPEV2 buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPEV2 buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
@@ -258,17 +258,17 @@ void main() {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
|
||||
ACC_TYPEV2 sums[WMITER * TM * WNITER * TN/2];
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC4 cache_b;
|
||||
FLOAT_TYPEV4 cache_a[WMITER * TM];
|
||||
FLOAT_TYPEV4 cache_b;
|
||||
#else
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b;
|
||||
FLOAT_TYPEV2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPEV2 cache_b;
|
||||
#endif
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
|
||||
sums[i] = ACC_TYPEV2(0.0f, 0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
@@ -11,38 +11,38 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_BF16)
|
||||
#if LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
@@ -57,10 +57,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v0.zw);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw);
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -73,10 +73,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
|
||||
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPEV2(v0.zw);
|
||||
buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy);
|
||||
buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw);
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -92,8 +92,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw);
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -112,10 +112,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
|
||||
const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xz);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v1.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw);
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -128,8 +128,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
|
||||
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -147,8 +147,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -171,8 +171,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);
|
||||
const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);
|
||||
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),
|
||||
dl * (qs.y - hm.y));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x),
|
||||
dl * (qs.y - hm.y));
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -206,8 +206,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -244,8 +244,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
|
||||
const vec4 q = vec4(unpack8(qs | qh));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m));
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -267,7 +267,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303;
|
||||
const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale;
|
||||
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y);
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -284,8 +284,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
@@ -306,8 +306,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
@@ -332,14 +332,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -358,14 +358,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -386,14 +386,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -414,10 +414,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint grid = iq3xxs_grid[qs];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -436,10 +436,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
@@ -456,8 +456,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = float(data_a[ib].d);
|
||||
const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw);
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -468,10 +468,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
|
||||
buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
|
||||
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
|
||||
buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
|
||||
kvalues_iq4nl[vui >> 12]);
|
||||
buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF],
|
||||
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
|
||||
buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
|
||||
kvalues_iq4nl[vui >> 12]);
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -483,10 +483,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -496,7 +496,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
@@ -505,9 +505,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
@@ -515,12 +515,12 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const uint idx = pos_b + col * p.stride_b + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -531,7 +531,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
@@ -541,9 +541,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
@@ -553,14 +553,14 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
|
||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -72,7 +72,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm);
|
||||
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
|
||||
}
|
||||
#endif
|
||||
@@ -203,7 +203,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
|
||||
}
|
||||
}
|
||||
@@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
|
||||
}
|
||||
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
|
||||
buf_a[buf_ib].dm = FLOAT_TYPEV2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,7 +426,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
buf_b[buf_ib].ds = FLOAT_TYPEV2(data_b[ib_outer].ds[ib_inner]);
|
||||
}
|
||||
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
@@ -436,7 +436,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo
|
||||
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
|
||||
} else {
|
||||
if (iqs == 0) {
|
||||
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
|
||||
buf_b[buf_ib].ds = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
|
||||
buf_b[buf_ib].qs[iqs * 4 ] = 0;
|
||||
|
||||
@@ -8,7 +8,7 @@ struct block_a_cache {
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
#define QUANT_R_MMQ 2
|
||||
@@ -22,7 +22,7 @@ struct block_a_cache {
|
||||
struct block_a_cache {
|
||||
uint32_t qs[16/4];
|
||||
uint32_t qh;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
#define QUANT_R_MMQ 1
|
||||
@@ -43,36 +43,36 @@ struct block_a_cache {
|
||||
struct block_a_cache {
|
||||
uint32_t qs[2];
|
||||
u8vec2 scales;
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
FLOAT_TYPEV2 d_scales;
|
||||
};
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define QUANT_R_MMQ 2
|
||||
struct block_a_cache {
|
||||
uint32_t qs[4];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 dm;
|
||||
FLOAT_TYPEV2 dm;
|
||||
};
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define QUANT_R_MMQ 1
|
||||
struct block_a_cache {
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 d_scales;
|
||||
FLOAT_TYPEV2 d_scales;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct block_b_cache
|
||||
{
|
||||
int32_t qs[8];
|
||||
FLOAT_TYPE_VEC2 ds;
|
||||
FLOAT_TYPEV2 ds;
|
||||
};
|
||||
|
||||
@@ -446,8 +446,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
base_dict["FLOAT16"] = "1";
|
||||
}
|
||||
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPEV2"] = f16acc ? "f16vec2" : "vec2";
|
||||
if (f16acc) {
|
||||
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
@@ -514,10 +514,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
};
|
||||
|
||||
const std::map<std::string, std::string> float_type_dict_f16 = {
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
|
||||
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
|
||||
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
|
||||
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
|
||||
{"FLOAT_TYPEV2", FLOAT_TYPE(2, "f16")},
|
||||
{"FLOAT_TYPEV4", FLOAT_TYPE(4, "f16")},
|
||||
{"FLOAT_TYPEV8", FLOAT_TYPE(8, "f16")},
|
||||
};
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
@@ -536,9 +536,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
|
||||
|
||||
const std::map<std::string, std::string> float_type_dict_bf16 = {
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
|
||||
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
|
||||
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
|
||||
{"FLOAT_TYPEV2", FLOAT_TYPE(2, "bf16")},
|
||||
{"FLOAT_TYPEV4", FLOAT_TYPE(4, "bf16")},
|
||||
};
|
||||
|
||||
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
|
||||
@@ -569,10 +569,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||
|
||||
const std::map<std::string, std::string> float_type_dict = {
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
|
||||
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
|
||||
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
|
||||
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
|
||||
{"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
|
||||
{"FLOAT_TYPEV2", FLOAT_TYPE(2, tname)},
|
||||
{"FLOAT_TYPEV4", FLOAT_TYPE(4, tname)},
|
||||
{"FLOAT_TYPEV8", FLOAT_TYPE(8, tname)},
|
||||
};
|
||||
|
||||
// don't generate f32 variants for coopmat2
|
||||
@@ -676,36 +676,36 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}};
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
// mul mat vec
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
// mul mat vec with integer dot product
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -726,9 +726,9 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
|
||||
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}});
|
||||
|
||||
// Norms
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -4033,8 +4033,14 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
|
||||
|
||||
static ggml_backend_webgpu_reg_context ctx;
|
||||
static ggml_backend_reg reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_webgpu_reg_i,
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 1;
|
||||
ctx.device_count = 0;
|
||||
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
@@ -4053,19 +4059,28 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
|
||||
ctx.webgpu_global_ctx->instance = std::move(inst);
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
if (ctx.webgpu_global_ctx->instance == nullptr) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx.webgpu_global_ctx->instance.WaitAny(
|
||||
ctx.webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
}
|
||||
|
||||
if (adapter != nullptr) {
|
||||
ctx.device_count = 1;
|
||||
}
|
||||
|
||||
static ggml_backend_reg reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_webgpu_reg_i,
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
return ®
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,111 @@
|
||||
ied 4 ½ months
|
||||
__ggml_vocab_test__
|
||||
Äpfel
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
this is 🦙.cpp
|
||||
__ggml_vocab_test__
|
||||
w048 7tuijk dsdfhu
|
||||
__ggml_vocab_test__
|
||||
нещо на Български
|
||||
__ggml_vocab_test__
|
||||
កាន់តែពិសេសអាចខលចេញ
|
||||
__ggml_vocab_test__
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
(
|
||||
__ggml_vocab_test__
|
||||
|
||||
=
|
||||
__ggml_vocab_test__
|
||||
' era
|
||||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
__ggml_vocab_test__
|
||||
333
|
||||
__ggml_vocab_test__
|
||||
3333
|
||||
__ggml_vocab_test__
|
||||
33333
|
||||
__ggml_vocab_test__
|
||||
333333
|
||||
__ggml_vocab_test__
|
||||
3333333
|
||||
__ggml_vocab_test__
|
||||
33333333
|
||||
__ggml_vocab_test__
|
||||
333333333
|
||||
__ggml_vocab_test__
|
||||
Cửa Việt
|
||||
__ggml_vocab_test__
|
||||
discards
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
|
||||
@@ -0,0 +1,46 @@
|
||||
1178 236743 236812 47041 3794
|
||||
239122 22744 535
|
||||
|
||||
236743
|
||||
138
|
||||
139
|
||||
255968
|
||||
107
|
||||
108
|
||||
109
|
||||
255968 107
|
||||
9259 1902
|
||||
26352 1902
|
||||
9259 4109
|
||||
26352 4109
|
||||
26352 4109 236888
|
||||
9259 236764 1902 236888
|
||||
26352 236764 1902 236888
|
||||
672 563 236743 478 397 404 391 236761 12362
|
||||
236765 236771 236812 236828 236743 236832 11372 12065 31806 3405 9360
|
||||
1337 12515 1333 4632 165543 3830
|
||||
234889 63031 219876 66212 239077 237907 144494
|
||||
242015 568 7382 236768 236743 247717 237243 248989 238178 568 43819 111730 150567 236768 113452 568 8960 64334 600 815 1061 1852 8369 236768
|
||||
9259
|
||||
26352
|
||||
138 9259
|
||||
139 9259
|
||||
140 9259
|
||||
140 9259 107 140 9259
|
||||
568
|
||||
107 578
|
||||
236789 6933
|
||||
9259 236764 570 236789 712 236888 2088 659 611 170124 2360 62133 237075 17641 11700 236770 236800 236770 236812 236770 236810 236770 237471 238352
|
||||
123947
|
||||
236800
|
||||
236800 236800
|
||||
236800 236800 236800
|
||||
236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800 236800 236800
|
||||
236800 236800 236800 236800 236800 236800 236800 236800 236800
|
||||
236780 29719 33154
|
||||
2243 2206
|
||||
107 236743 108 236743 109 236743 255968 236743 255969 236743 255968 107 138 107 139 107 140 107 141 107 242015 568 7382 236768 236743 247717 237243 248989 238178 568 43819 111730 150567 236768 113452 236743 478 397 404 391 478 397 404 391 236743 236800 236743 236800 236800 236743 236800 236800 236800 236743 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236800 236743 236800 236800 236800 236800 236800 236800 236800 236800 236743 236800 236761 236800 236743 236800 856 236800 236743 236800 1390 236800 90986 92814 63031 219876 66212 241702 2360 62133 237075 17641 11700 236770 236800 236770 236812 236770 236810 236770 237471 238352 80448 120697 210119 1333 4632 165543 3830 9451 159561 2629 2629 2717 84491 19938 123947 38950 10371 564 236789 560 1010 756 151812 668 236789 236751 993 236764 756 1357 611 2889 236881 756 236792 711 2889 564 236789 859 1386 625 236764 756 236796 611 1133 1070 11115 236881 1191 236789 32541 496 236789 95635
|
||||
+16
-16
@@ -558,20 +558,20 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
// example: https://github.com/ggml-org/llama.cpp/pull/17548
|
||||
//
|
||||
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
|
||||
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
{LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
{LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
@@ -708,9 +708,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
// altup / laurel (gemma 3n)
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
||||
@@ -2942,7 +2942,7 @@ llama_context * llama_init_from_model(
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
|
||||
@@ -2953,7 +2953,7 @@ llama_context * llama_init_from_model(
|
||||
}
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
|
||||
|
||||
+10
-9
@@ -4211,13 +4211,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
|
||||
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
|
||||
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0);
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
@@ -4276,9 +4277,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
if (n_embd_per_layer > 0) {
|
||||
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
|
||||
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0);
|
||||
}
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
+1
-1
@@ -534,9 +534,9 @@ struct llama_model {
|
||||
struct ggml_tensor * conv1d_b = nullptr;
|
||||
|
||||
// gemma3n altup
|
||||
struct ggml_tensor * tok_embd_per_layer = nullptr;
|
||||
struct ggml_tensor * altup_proj = nullptr;
|
||||
struct ggml_tensor * altup_unembd_proj = nullptr;
|
||||
struct ggml_tensor * per_layer_tok_embd = nullptr;
|
||||
struct ggml_tensor * per_layer_model_proj = nullptr;
|
||||
struct ggml_tensor * per_layer_proj_norm = nullptr;
|
||||
|
||||
|
||||
+40
-3
@@ -659,8 +659,17 @@ struct llm_tokenizer_bpe_session {
|
||||
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
for (auto j = str.begin(); j != str.end(); ++j) {
|
||||
std::string byte_str(1, *j);
|
||||
auto token_multibyte = vocab.text_to_token(byte_str);
|
||||
llama_token token_multibyte = LLAMA_TOKEN_NULL;
|
||||
if (tokenizer.byte_encode) {
|
||||
std::string byte_str(1, *j);
|
||||
token_multibyte = vocab.text_to_token(byte_str);
|
||||
} else {
|
||||
// For non-byte-encoded BPE (e.g. gemma-4), byte tokens use <0xXX> format
|
||||
static const char * hex = "0123456789ABCDEF";
|
||||
const uint8_t ch = (uint8_t)*j;
|
||||
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
|
||||
token_multibyte = vocab.text_to_token(buf);
|
||||
}
|
||||
if (token_multibyte != LLAMA_TOKEN_NULL) {
|
||||
output.push_back(token_multibyte);
|
||||
}
|
||||
@@ -2558,7 +2567,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "[EOS]" // Kimi-K2
|
||||
|| t.first == "<|end_of_text|>"
|
||||
|| t.first == "<end_of_utterance>" // smoldocling
|
||||
|| t.first == "<turn|>" // gemma4
|
||||
|| t.first == "<eos>" // gemma4
|
||||
|| t.first == "<turn|>" // gemma4
|
||||
|| t.first == "<|tool_response>" // gemma4
|
||||
|| t.first == "<|end▁of▁sentence|>" // deepseek-ocr
|
||||
) {
|
||||
@@ -2645,6 +2655,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// workaround for gemma4 and paddleocr: do not include </s> as an eog token
|
||||
{
|
||||
bool has_tool_response = false;
|
||||
bool has_s = false;
|
||||
|
||||
llama_token s_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
for (auto tid : special_eog_ids) {
|
||||
const auto & text = id_to_token[tid].text;
|
||||
if (text == "<|tool_response>") {
|
||||
has_tool_response = true;
|
||||
} else if (text == "</s>") {
|
||||
has_s = true;
|
||||
s_id = tid;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_tool_response && has_s) {
|
||||
special_eog_ids.erase(s_id);
|
||||
|
||||
auto & attr = id_to_token[s_id].attr;
|
||||
attr = LLAMA_TOKEN_ATTR_NORMAL;
|
||||
|
||||
LLAMA_LOG_WARN("%s: special_eog_ids contains '<|tool_response>', removing '</s>' token from EOG list\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// build special tokens cache
|
||||
|
||||
+36
-32
@@ -1,5 +1,12 @@
|
||||
#include "models.h"
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
@@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
|
||||
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
|
||||
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
|
||||
ggml_tensor * inp_per_layer = build_inp_per_layer();
|
||||
ggml_build_forward_expand(gf, inp_per_layer);
|
||||
|
||||
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
|
||||
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
|
||||
|
||||
// inpL now has only 1 altup, project it to the rest of the altups
|
||||
// these "added" altups will be concat to the last dim of inpL
|
||||
@@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
|
||||
cb(inpL, "inp_stacked", -1);
|
||||
}
|
||||
// inpL now has shape: [n_embd, n_tokens, n_altup]
|
||||
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
|
||||
// inpL now has shape: [n_embd, n_tokens, n_altup]
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
|
||||
@@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
|
||||
|
||||
// predicted value will go through self-attention and laurel
|
||||
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
|
||||
cur = active_prediction;
|
||||
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens]
|
||||
cur = active_prediction;
|
||||
cb(cur, "active_prediction", il);
|
||||
|
||||
// norm
|
||||
@@ -151,12 +160,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
|
||||
ggml_tensor * first_prediction; // [n_embd, n_tokens]
|
||||
{
|
||||
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
|
||||
first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens]
|
||||
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
|
||||
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
|
||||
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
|
||||
cb(first_prediction, "first_prediction_gated", il);
|
||||
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
|
||||
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens]
|
||||
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
|
||||
cb(first_prediction, "first_prediction_scaled", il);
|
||||
|
||||
@@ -167,7 +177,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
}
|
||||
// equivalent to python code: corrected_predictions[1:] += first_prediction
|
||||
{
|
||||
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
|
||||
ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0);
|
||||
ggml_tensor * slice_rest = ggml_view_3d(
|
||||
ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
|
||||
ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
|
||||
@@ -185,7 +195,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
|
||||
// cur now has multiple altup(s), we want to merge them back to 1 altup
|
||||
{
|
||||
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
|
||||
ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens]
|
||||
// do a view to skip the first slice (active altup)
|
||||
ggml_tensor * alt_slice =
|
||||
ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
|
||||
@@ -197,9 +207,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
cb(altup_unembd, "altup_unembd", -1);
|
||||
|
||||
// equivalent to torch.mean(hidden_states, dim=0)
|
||||
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
|
||||
cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens]
|
||||
for (int i = 0; i < n_altup - 1; ++i) {
|
||||
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
|
||||
cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i));
|
||||
}
|
||||
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
|
||||
cb(cur, "unembd_merged", -1);
|
||||
@@ -235,23 +245,16 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
|
||||
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
ggml_tensor * inp_per_layer;
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.per_layer_tok_embd, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
|
||||
cb(inp_per_layer, "inp_per_layer_selected", -1);
|
||||
@@ -259,10 +262,10 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
} else {
|
||||
// Vision embedding path: use padding token (ID=0) embedding
|
||||
// TODO: verify if this is the correct behavior in transformers implementation
|
||||
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
|
||||
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer
|
||||
|
||||
// Extract and dequantize padding token embedding (row 0)
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
|
||||
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
|
||||
|
||||
// Reshape to [n_embd_altup, n_layer, 1]
|
||||
@@ -275,18 +278,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
// equivalent to project_per_layer_inputs() in python code
|
||||
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
|
||||
// output shape: [n_embd_altup, n_tokens, n_layer]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
|
||||
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
|
||||
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
|
||||
|
||||
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
|
||||
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
|
||||
-1); // [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * per_layer_proj;
|
||||
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
|
||||
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
|
||||
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(per_layer_proj, "per_layer_proj", -1);
|
||||
|
||||
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
|
||||
cb(inp_per_layer, "inp_per_layer", -1);
|
||||
|
||||
@@ -337,7 +341,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso
|
||||
// input cur shape: [n_embd, n_tokens, n_altup]
|
||||
// output shape: [n_embd, n_tokens, n_altup]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
|
||||
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
|
||||
ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens]
|
||||
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
|
||||
cb(modalities, "modalities", il);
|
||||
|
||||
@@ -365,7 +369,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g
|
||||
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
|
||||
cb(modalities, "modalities", il);
|
||||
|
||||
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
|
||||
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);
|
||||
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
|
||||
cb(innovation, "innovation", il);
|
||||
|
||||
|
||||
+37
-28
@@ -1,5 +1,12 @@
|
||||
#include "models.h"
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
@@ -19,14 +26,17 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * inp_per_layer = nullptr;
|
||||
if (model.tok_embd_per_layer) {
|
||||
inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
|
||||
}
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * inp_per_layer = nullptr;
|
||||
if (model.per_layer_tok_embd) {
|
||||
inp_per_layer = build_inp_per_layer();
|
||||
ggml_build_forward_expand(gf, inp_per_layer);
|
||||
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(il);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il));
|
||||
@@ -196,7 +206,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
|
||||
cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
// TODO @ngxson : improve this
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -248,34 +259,30 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
ggml_tensor * llm_build_gemma4_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_per_layer, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
|
||||
ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
|
||||
ggml_tensor * inp_per_layer;
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
|
||||
inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
|
||||
inp_per_layer = ggml_scale (ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
|
||||
cb(inp_per_layer, "inp_per_layer_selected", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
} else {
|
||||
// Vision embedding path: use padding token (ID=0) embedding
|
||||
// TODO: verify if this is the correct behavior in transformers implementation
|
||||
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_per_layer * n_layer
|
||||
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer
|
||||
|
||||
// Extract and dequantize padding token embedding (row 0)
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
|
||||
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
|
||||
|
||||
// Reshape to [n_embd_per_layer, n_layer, 1]
|
||||
@@ -287,21 +294,23 @@ ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
|
||||
|
||||
// equivalent to project_per_layer_inputs() in python code
|
||||
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
|
||||
// inputs_embeds shape: [n_embd, n_tokens]
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from get_per_layer_inputs)
|
||||
// inp_batch shape: [n_embd, n_tokens]
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer)
|
||||
// output shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
|
||||
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
|
||||
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
|
||||
|
||||
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
|
||||
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS,
|
||||
-1); // [n_embd_per_layer, n_layer, n_tokens]
|
||||
// note: this matrix multiplication will be performed in the input layer (i.e. on the CPU)
|
||||
ggml_tensor * per_layer_proj;
|
||||
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
|
||||
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
|
||||
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(per_layer_proj, "per_layer_proj", -1);
|
||||
|
||||
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
|
||||
cb(inp_per_layer, "inp_per_layer", -1);
|
||||
|
||||
|
||||
+9
-6
@@ -256,9 +256,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
|
||||
llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * calc_magnitude(ggml_tensor * x);
|
||||
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
|
||||
ggml_tensor * get_per_layer_inputs();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
|
||||
|
||||
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
|
||||
ggml_tensor * build_inp_per_layer();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
|
||||
|
||||
ggml_tensor * gaussian_topk(ggml_tensor * x);
|
||||
ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il);
|
||||
ggml_tensor * altup_predict(ggml_tensor * cur, int il);
|
||||
@@ -272,9 +274,10 @@ struct llm_build_gemma4_iswa : public llm_graph_context {
|
||||
const int64_t n_embd_per_layer;
|
||||
|
||||
llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
|
||||
ggml_tensor * get_per_layer_inputs();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
|
||||
|
||||
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
|
||||
ggml_tensor * build_inp_per_layer();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
|
||||
};
|
||||
|
||||
struct llm_build_gemma_embedding : public llm_graph_context {
|
||||
|
||||
@@ -124,6 +124,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${PROJE
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-coder.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-llm.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-falcon.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gemma-4 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gemma-4.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-2.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf)
|
||||
|
||||
@@ -7251,6 +7251,7 @@ static const ggml_type all_types[] = {
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q1_0,
|
||||
GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
|
||||
@@ -7275,6 +7276,7 @@ static const ggml_type other_types[] = {
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q1_0,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K,
|
||||
|
||||
+1
-1
@@ -3454,7 +3454,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
},
|
||||
"replaceAll": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to replace all occurences."
|
||||
"description": "Whether to replace all occurrences."
|
||||
}
|
||||
},
|
||||
"required": ["oldString", "newString"]
|
||||
|
||||
@@ -447,6 +447,18 @@ static void test_expressions(testing & t) {
|
||||
"hello world"
|
||||
);
|
||||
|
||||
test_template(t, "string repetition",
|
||||
"{{ 'ab' * 3 }}",
|
||||
json::object(),
|
||||
"ababab"
|
||||
);
|
||||
|
||||
test_template(t, "reversed string repetition",
|
||||
"{{ 3 * 'ab' }}",
|
||||
json::object(),
|
||||
"ababab"
|
||||
);
|
||||
|
||||
test_template(t, "ternary",
|
||||
"{{ 'yes' if cond else 'no' }}",
|
||||
{{"cond", true}},
|
||||
@@ -693,6 +705,33 @@ static void test_filters(testing & t) {
|
||||
"\"\\u2713\""
|
||||
);
|
||||
|
||||
test_template(t, "tojson ensure_ascii=true nested object",
|
||||
"{{ data|tojson(ensure_ascii=true) }}",
|
||||
{{"data", {
|
||||
{"text", "\u2713"},
|
||||
{"items", json::array({"é", {{"snowman", "☃"}}})}
|
||||
}}},
|
||||
"{\"text\": \"\\u2713\", \"items\": [\"\\u00e9\", {\"snowman\": \"\\u2603\"}]}"
|
||||
);
|
||||
|
||||
test_template(t, "tojson ensure_ascii=true indent=2",
|
||||
"{{ data|tojson(ensure_ascii=true, indent=2) }}",
|
||||
{{"data", {
|
||||
{"text", "\u2713"},
|
||||
{"nested", {{"accent", "é"}}}
|
||||
}}},
|
||||
"{\n \"text\": \"\\u2713\",\n \"nested\": {\n \"accent\": \"\\u00e9\"\n }\n}"
|
||||
);
|
||||
|
||||
test_template(t, "tojson ensure_ascii=true preserves existing escapes",
|
||||
"{{ data|tojson(ensure_ascii=true) }}",
|
||||
{{"data", {
|
||||
{"emoji", "😀"},
|
||||
{"line", "a\nb"}
|
||||
}}},
|
||||
"{\"emoji\": \"\\ud83d\\ude00\", \"line\": \"a\\nb\"}"
|
||||
);
|
||||
|
||||
test_template(t, "tojson sort_keys=true",
|
||||
"{{ data|tojson(sort_keys=true) }}",
|
||||
{{"data", {{"b", 2}, {"a", 1}}}},
|
||||
@@ -771,6 +810,12 @@ static void test_filters(testing & t) {
|
||||
"hello"
|
||||
);
|
||||
|
||||
test_template(t, "int filter on integer is identity",
|
||||
"{{ value|int }}",
|
||||
{{"value", 7}},
|
||||
"7"
|
||||
);
|
||||
|
||||
test_template(t, "none to string",
|
||||
"{{ x|string }}",
|
||||
{{"x", nullptr}},
|
||||
@@ -2458,4 +2503,12 @@ static void test_fuzzing(testing & t) {
|
||||
t.assert_true("builtin " + type_name + "." + fn_name + " #" + std::to_string(i), fuzz_test_template(tmpl, vars));
|
||||
}
|
||||
});
|
||||
|
||||
t.test("tojson ensure_ascii=true with invalid utf-8", [&](testing & t) {
|
||||
t.assert_true("invalid utf-8 does not crash",
|
||||
fuzz_test_template(
|
||||
"{{ data|tojson(ensure_ascii=true) }}",
|
||||
{{"data", std::string("hello\xfe\xffworld")}}
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -18,7 +18,7 @@
|
||||
<div style="display: contents">
|
||||
<script>
|
||||
{
|
||||
__sveltekit_1ppa22i = {
|
||||
__sveltekit_6n4hpv = {
|
||||
base: new URL('.', location).pathname.slice(0, -1)
|
||||
};
|
||||
|
||||
|
||||
@@ -3033,6 +3033,8 @@ server_context_meta server_context::get_meta() const {
|
||||
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
|
||||
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
|
||||
|
||||
/* logit_bias_eog */ impl->params_base.sampling.logit_bias_eog,
|
||||
|
||||
/* model_vocab_type */ llama_vocab_type(impl->vocab),
|
||||
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
|
||||
/* model_n_ctx_train */ llama_model_n_ctx_train(impl->model),
|
||||
@@ -3117,6 +3119,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
ctx_server.vocab,
|
||||
params,
|
||||
meta->slot_n_ctx,
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
|
||||
@@ -39,6 +39,9 @@ struct server_context_meta {
|
||||
llama_token fim_rep_token;
|
||||
llama_token fim_sep_token;
|
||||
|
||||
// sampling
|
||||
std::vector<llama_logit_bias> logit_bias_eog;
|
||||
|
||||
// model meta
|
||||
enum llama_vocab_type model_vocab_type;
|
||||
int32_t model_vocab_n_tokens;
|
||||
|
||||
@@ -239,6 +239,7 @@ task_params server_task::params_from_json_cmpl(
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data) {
|
||||
task_params params;
|
||||
|
||||
@@ -383,6 +384,8 @@ task_params server_task::params_from_json_cmpl(
|
||||
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar = defaults.sampling.grammar;
|
||||
|
||||
std::string grammar_str = json_value(data, "grammar", std::string());
|
||||
if (!grammar_str.empty()) {
|
||||
// grammar_type key is set by the server when converting chat template grammars
|
||||
@@ -562,7 +565,7 @@ task_params server_task::params_from_json_cmpl(
|
||||
if (params.sampling.ignore_eos) {
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
|
||||
logit_bias_eog.begin(), logit_bias_eog.end());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -209,6 +209,7 @@ struct server_task {
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
const int n_ctx_slot,
|
||||
const std::vector<llama_logit_bias> & logit_bias_eog,
|
||||
const json & data);
|
||||
|
||||
// utility function
|
||||
|
||||
@@ -135,7 +135,7 @@ def test_completion_stream_with_openai_library_stops():
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.completions.create(
|
||||
model="davinci-002",
|
||||
prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
|
||||
prompt="System: You are helpful assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
|
||||
stop=["User:\n", "Assistant:\n"],
|
||||
max_tokens=200,
|
||||
stream=True,
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
def test_ignore_eos_populates_logit_bias():
|
||||
"""ignore_eos=true must add EOG logit biases to generation_settings."""
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "Once upon a time",
|
||||
"ignore_eos": True,
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
# EOG token biases must be present with -inf bias
|
||||
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||
assert len(logit_bias) > 0
|
||||
for entry in logit_bias:
|
||||
assert entry["bias"] is None # null in JSON represents -inf
|
||||
|
||||
|
||||
def test_ignore_eos_false_no_logit_bias():
|
||||
"""ignore_eos=false (default) must NOT add EOG logit biases."""
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "Once upon a time",
|
||||
"ignore_eos": False,
|
||||
"temperature": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
logit_bias = res.body["generation_settings"]["logit_bias"]
|
||||
assert len(logit_bias) == 0
|
||||
+5
-1
@@ -62,10 +62,14 @@
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
);
|
||||
|
||||
let lastSyncedConversationModel: string | null = null;
|
||||
|
||||
$effect(() => {
|
||||
if (conversationModel) {
|
||||
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
|
||||
lastSyncedConversationModel = conversationModel;
|
||||
modelsStore.selectModelByName(conversationModel);
|
||||
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
|
||||
lastSyncedConversationModel = null;
|
||||
// auto-select the first loaded model only when nothing is selected yet
|
||||
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
|
||||
if (first) modelsStore.selectModelById(first.id);
|
||||
|
||||
@@ -291,14 +291,19 @@
|
||||
title: SETTINGS_SECTION_TITLES.DEVELOPER,
|
||||
icon: Code,
|
||||
fields: [
|
||||
{
|
||||
key: SETTINGS_KEYS.PRE_ENCODE_CONVERSATION,
|
||||
label: 'Pre-fill KV cache after response',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.DISABLE_REASONING_PARSING,
|
||||
label: 'Disable reasoning content parsing',
|
||||
label: 'Disable server-side thinking extraction',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.EXCLUDE_REASONING_FROM_CONTEXT,
|
||||
label: 'Exclude reasoning from context',
|
||||
label: 'Strip thinking from message history',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
|
||||
@@ -56,6 +56,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
|
||||
dry_penalty_last_n: undefined,
|
||||
max_tokens: undefined,
|
||||
custom: '', // custom json-stringified object
|
||||
preEncodeConversation: false,
|
||||
// experimental features
|
||||
pyInterpreterEnabled: false,
|
||||
enableContinueGeneration: false
|
||||
@@ -106,9 +107,9 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.',
|
||||
showThoughtInProgress: 'Expand thought process by default when generating messages.',
|
||||
disableReasoningParsing:
|
||||
'Send reasoning_format=none to prevent server-side extraction of reasoning tokens into separate field',
|
||||
'Send reasoning_format=none so the server returns thinking tokens inline instead of extracting them into a separate field.',
|
||||
excludeReasoningFromContext:
|
||||
'Strip reasoning content from previous messages before sending to the model. When unchecked, reasoning is sent back via the reasoning_content field so the model can see its own chain-of-thought across turns.',
|
||||
'Strip thinking from previous messages before sending. When off, thinking is sent back via the reasoning_content field so the model sees its own chain-of-thought across turns.',
|
||||
showRawOutputSwitch:
|
||||
'Show toggle button to display messages as plain text instead of Markdown-formatted content',
|
||||
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
|
||||
@@ -143,6 +144,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
'Automatically expand tool call details while executing and keep them expanded after completion.',
|
||||
pyInterpreterEnabled:
|
||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
|
||||
preEncodeConversation:
|
||||
'After each response, re-submit the conversation to pre-fill the server KV cache. Makes the next turn faster since the prompt is already encoded while you read the response.',
|
||||
enableContinueGeneration:
|
||||
'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.'
|
||||
};
|
||||
|
||||
@@ -52,6 +52,8 @@ export const SETTINGS_KEYS = {
|
||||
ALWAYS_SHOW_AGENTIC_TURNS: 'alwaysShowAgenticTurns',
|
||||
AGENTIC_MAX_TOOL_PREVIEW_LINES: 'agenticMaxToolPreviewLines',
|
||||
SHOW_TOOL_CALL_IN_PROGRESS: 'showToolCallInProgress',
|
||||
// Performance
|
||||
PRE_ENCODE_CONVERSATION: 'preEncodeConversation',
|
||||
// Developer
|
||||
DISABLE_REASONING_PARSING: 'disableReasoningParsing',
|
||||
EXCLUDE_REASONING_FROM_CONTEXT: 'excludeReasoningFromContext',
|
||||
|
||||
@@ -4,7 +4,8 @@ import { isAbortError } from '$lib/utils/abort';
|
||||
import {
|
||||
ATTACHMENT_LABEL_PDF_FILE,
|
||||
ATTACHMENT_LABEL_MCP_PROMPT,
|
||||
ATTACHMENT_LABEL_MCP_RESOURCE
|
||||
ATTACHMENT_LABEL_MCP_RESOURCE,
|
||||
LEGACY_AGENTIC_REGEX
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
AttachmentType,
|
||||
@@ -279,6 +280,107 @@ export class ChatService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks whether all server slots are currently idle (not processing any requests).
|
||||
* Queries the /slots endpoint (requires --slots flag on the server).
|
||||
* Returns true if all slots are idle, false if any is processing.
|
||||
* If the endpoint is unavailable or errors out, returns true (best-effort fallback).
|
||||
*
|
||||
* @param signal - Optional AbortSignal to cancel the request if needed
|
||||
* @param model - Optional model name to check slots for (required in ROUTER mode)
|
||||
* @returns {Promise<boolean>} Promise that resolves to true if all slots are idle, false if any is processing
|
||||
*/
|
||||
static async areAllSlotsIdle(model?: string | null, signal?: AbortSignal): Promise<boolean> {
|
||||
try {
|
||||
const url = model ? `./slots?model=${encodeURIComponent(model)}` : './slots';
|
||||
const res = await fetch(url, { signal });
|
||||
if (!res.ok) return true;
|
||||
|
||||
const slots: { is_processing: boolean }[] = await res.json();
|
||||
return slots.every((s) => !s.is_processing);
|
||||
} catch {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a fire-and-forget request to pre-encode the conversation in the server's KV cache.
|
||||
* After a response completes, this re-submits the full conversation
|
||||
* using n_predict=0 and stream=false so the server processes the prompt without generating tokens.
|
||||
* This warms the cache for the next turn, making it faster.
|
||||
*
|
||||
* When excludeReasoningFromContext is true, reasoning content is stripped from the messages
|
||||
* to match what sendMessage would send on the next turn (avoiding cache misses).
|
||||
* When false, reasoning_content is preserved so the cached prompt matches the next request.
|
||||
*
|
||||
* @param messages - The full conversation including the latest assistant response
|
||||
* @param model - Optional model name (required in ROUTER mode)
|
||||
* @param excludeReasoning - Whether to strip reasoning content (should match excludeReasoningFromContext setting)
|
||||
* @param signal - Optional AbortSignal to cancel the pre-encode request
|
||||
*/
|
||||
static async preEncode(
|
||||
messages: ApiChatMessageData[] | (DatabaseMessage & { extra?: DatabaseMessageExtra[] })[],
|
||||
model?: string | null,
|
||||
excludeReasoning?: boolean,
|
||||
signal?: AbortSignal
|
||||
): Promise<void> {
|
||||
const normalizedMessages: ApiChatMessageData[] = messages
|
||||
.map((msg) => {
|
||||
if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
|
||||
return ChatService.convertDbMessageToApiChatMessageData(
|
||||
msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] }
|
||||
);
|
||||
}
|
||||
|
||||
return msg as ApiChatMessageData;
|
||||
})
|
||||
.filter((msg) => {
|
||||
if (msg.role === MessageRole.SYSTEM) {
|
||||
const content = typeof msg.content === 'string' ? msg.content : '';
|
||||
|
||||
return content.trim().length > 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
const requestBody: Record<string, unknown> = {
|
||||
messages: normalizedMessages.map((msg: ApiChatMessageData) => {
|
||||
const mapped: Record<string, unknown> = {
|
||||
role: msg.role,
|
||||
content: excludeReasoning ? ChatService.stripReasoningContent(msg.content) : msg.content,
|
||||
tool_calls: msg.tool_calls,
|
||||
tool_call_id: msg.tool_call_id
|
||||
};
|
||||
|
||||
if (!excludeReasoning && msg.reasoning_content) {
|
||||
mapped.reasoning_content = msg.reasoning_content;
|
||||
}
|
||||
|
||||
return mapped;
|
||||
}),
|
||||
stream: false,
|
||||
n_predict: 0
|
||||
};
|
||||
|
||||
if (model) {
|
||||
requestBody.model = model;
|
||||
}
|
||||
|
||||
try {
|
||||
await fetch(`./v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: getJsonHeaders(),
|
||||
body: JSON.stringify(requestBody),
|
||||
signal
|
||||
});
|
||||
} catch (error) {
|
||||
if (!isAbortError(error)) {
|
||||
console.warn('[ChatService] Pre-encode request failed:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
@@ -799,6 +901,28 @@ export class ChatService {
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Strips legacy inline reasoning content tags from message content.
|
||||
* Handles both plain string content and multipart content arrays.
|
||||
*/
|
||||
private static stripReasoningContent(
|
||||
content: string | ApiChatMessageContentPart[]
|
||||
): string | ApiChatMessageContentPart[] {
|
||||
const stripFromString = (text: string): string =>
|
||||
text.replace(LEGACY_AGENTIC_REGEX.REASONING_BLOCK, '').trim();
|
||||
|
||||
if (typeof content === 'string') {
|
||||
return stripFromString(content);
|
||||
}
|
||||
|
||||
return content.map((part) => {
|
||||
if (part.type === ContentPartType.TEXT && part.text) {
|
||||
return { ...part, text: stripFromString(part.text) };
|
||||
}
|
||||
return part;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses error response and creates appropriate error with context information
|
||||
* @param response - HTTP response object
|
||||
|
||||
@@ -88,6 +88,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
||||
},
|
||||
{ key: 'max_tokens', serverKey: 'max_tokens', type: SyncableParameterType.NUMBER, canSync: true },
|
||||
{ key: 'samplers', serverKey: 'samplers', type: SyncableParameterType.STRING, canSync: true },
|
||||
{
|
||||
key: 'backend_sampling',
|
||||
serverKey: 'backend_sampling',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'pasteLongTextToFileLen',
|
||||
serverKey: 'pasteLongTextToFileLen',
|
||||
|
||||
@@ -58,6 +58,7 @@ class ChatStore {
|
||||
chatLoadingStates = new SvelteMap<string, boolean>();
|
||||
chatStreamingStates = new SvelteMap<string, { response: string; messageId: string }>();
|
||||
private abortControllers = new SvelteMap<string, AbortController>();
|
||||
private preEncodeAbortController: AbortController | null = null;
|
||||
private processingStates = new SvelteMap<string, ApiProcessingState | null>();
|
||||
private conversationStateTimestamps = new SvelteMap<string, ConversationStateEntry>();
|
||||
private activeConversationId = $state<string | null>(null);
|
||||
@@ -462,6 +463,9 @@ class ChatStore {
|
||||
const activeConv = conversationsStore.activeConversation;
|
||||
if (activeConv && this.isChatLoadingInternal(activeConv.id)) return;
|
||||
|
||||
// Cancel any in-flight pre-encode request
|
||||
this.cancelPreEncode();
|
||||
|
||||
// Consume MCP resource attachments - converts them to extras and clears the live store
|
||||
const resourceExtras = mcpStore.consumeResourceAttachmentsAsExtras();
|
||||
const allExtras = resourceExtras.length > 0 ? [...(extras || []), ...resourceExtras] : extras;
|
||||
@@ -724,6 +728,16 @@ class ChatStore {
|
||||
|
||||
if (onComplete) onComplete(streamedContent);
|
||||
if (isRouterMode()) modelsStore.fetchRouterModels().catch(console.error);
|
||||
// Pre-encode conversation in KV cache for faster next turn
|
||||
if (config().preEncodeConversation) {
|
||||
this.triggerPreEncode(
|
||||
allMessages,
|
||||
assistantMessage,
|
||||
streamedContent,
|
||||
effectiveModel,
|
||||
!!config().excludeReasoningFromContext
|
||||
);
|
||||
}
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
this.setStreamingActive(false);
|
||||
@@ -911,6 +925,7 @@ class ChatStore {
|
||||
async regenerateMessage(messageId: string): Promise<void> {
|
||||
const activeConv = conversationsStore.activeConversation;
|
||||
if (!activeConv || this.isChatLoadingInternal(activeConv.id)) return;
|
||||
this.cancelPreEncode();
|
||||
const result = this.getMessageByIdWithRole(messageId, MessageRole.ASSISTANT);
|
||||
if (!result) return;
|
||||
const { index: messageIndex } = result;
|
||||
@@ -940,6 +955,7 @@ class ChatStore {
|
||||
async regenerateMessageWithBranching(messageId: string, modelOverride?: string): Promise<void> {
|
||||
const activeConv = conversationsStore.activeConversation;
|
||||
if (!activeConv || this.isChatLoadingInternal(activeConv.id)) return;
|
||||
this.cancelPreEncode();
|
||||
try {
|
||||
const idx = conversationsStore.findMessageIndex(messageId);
|
||||
if (idx === -1) return;
|
||||
@@ -1610,13 +1626,48 @@ class ChatStore {
|
||||
|
||||
if (currentConfig.samplers) apiOptions.samplers = currentConfig.samplers;
|
||||
|
||||
if (currentConfig.backend_sampling)
|
||||
apiOptions.backend_sampling = currentConfig.backend_sampling;
|
||||
apiOptions.backend_sampling = currentConfig.backend_sampling;
|
||||
|
||||
if (currentConfig.custom) apiOptions.custom = currentConfig.custom;
|
||||
|
||||
return apiOptions;
|
||||
}
|
||||
|
||||
private cancelPreEncode(): void {
|
||||
if (this.preEncodeAbortController) {
|
||||
this.preEncodeAbortController.abort();
|
||||
this.preEncodeAbortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
private async triggerPreEncode(
|
||||
allMessages: DatabaseMessage[],
|
||||
assistantMessage: DatabaseMessage,
|
||||
assistantContent: string,
|
||||
model?: string | null,
|
||||
excludeReasoning?: boolean
|
||||
): Promise<void> {
|
||||
this.cancelPreEncode();
|
||||
this.preEncodeAbortController = new AbortController();
|
||||
|
||||
const signal = this.preEncodeAbortController.signal;
|
||||
|
||||
try {
|
||||
const allIdle = await ChatService.areAllSlotsIdle(model, signal);
|
||||
if (!allIdle || signal.aborted) return;
|
||||
|
||||
const messagesWithAssistant: DatabaseMessage[] = [
|
||||
...allMessages,
|
||||
{ ...assistantMessage, content: assistantContent }
|
||||
];
|
||||
|
||||
await ChatService.preEncode(messagesWithAssistant, model, excludeReasoning, signal);
|
||||
} catch (err) {
|
||||
if (!isAbortError(err)) {
|
||||
console.warn('[ChatStore] Pre-encode failed:', err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const chatStore = new ChatStore();
|
||||
|
||||
@@ -77,6 +77,11 @@
|
||||
!modelsStore.isModelLoaded(modelsStore.selectedModelName)
|
||||
) {
|
||||
modelsStore.clearSelection();
|
||||
|
||||
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
|
||||
if (first) {
|
||||
await modelsStore.selectModelById(first.id);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle URL params only if we have ?q= or ?model= or ?new_chat=true
|
||||
|
||||
Reference in New Issue
Block a user