Compare commits

...

43 Commits

Author SHA1 Message Date
Diego Devesa a394039db0 ggml-cpu : add chunking support to mul_mat_id (#11666)
* ggml-cpu : add chunking support to mul_mat_id

* allocate chunk counter in wdata
parallelize src1 quantization by column to allows parallelization even when there is only one row

* disable for arm

* cleanup

* better way to disable for arm

* fix uninitialized counter when using 1 thread only

* revert test-backend-ops changes
2025-02-13 01:02:38 +01:00
Xuan-Son Nguyen be3bbd6215 ggml : x2 speed for WASM by optimizing SIMD (#11453)
* ggml : x2 speed for WASM by optimizing SIMD

* fix bad merging

* rm trailing spaces

* rm redundant clamp

* better quantize_row_q8_K

Co-authored-by: camel-cdr <camel-cdr@protonmail.com>

* remove memset that causes buffer overflow
Co-authored-by: camel-cdr <camel-cdr@protonmail.com>

---------

Co-authored-by: camel-cdr <camel-cdr@protonmail.com>
2025-02-13 00:33:45 +01:00
Woof Dog 31afcbee0e server : (webui) Give copy button back to all message bubbles (#11814)
* All messages get the copy button

* Update index.html.gz
2025-02-12 23:47:11 +01:00
uvos 5c4284d57b HIP: Remove GCN from list of devices that avoid MMQ (#11831) 2025-02-12 22:25:28 +01:00
JC bfd11a2344 Fix: Compile failure due to Microsoft STL breaking change (#11836) 2025-02-12 21:36:11 +01:00
Georgi Gerganov 0fb77f821f sync : ggml 2025-02-12 21:46:02 +02:00
uvos e598697d63 HIP: Switch to std::vector in rocblas version check (#11820) 2025-02-12 17:25:03 +01:00
bandoti fef0cbeadf cleanup: fix compile warnings associated with gnu_printf (#11811) 2025-02-12 10:06:53 -04:00
Richard 748ee9fe93 ggml : fix multi-threaded clamp_f32 (#11824)
* Bug fix for clamp_f32

When using tensors larger than 1d clamp operation does not work due to the restriction of returning if ith is not 0.

* Bug fix for clamp_f32

* Bug fix for clamp_f32
2025-02-12 15:57:33 +02:00
Weizhao Ouyang 198b1ec611 ggml-cpu: Fix duplicate MATMUL_INT8 (#11817)
Signed-off-by: Weizhao Ouyang <o451686892@gmail.com>
2025-02-12 13:22:58 +01:00
Johannes Gäßler c3d6af7cd2 CUDA: fix CUDART_VERSION checks (#11821) 2025-02-12 13:16:39 +01:00
Daniel Bevenius 369be5598a llama : fix typo in llama-grammar.h [no ci] (#11816) 2025-02-12 09:40:01 +02:00
lhez 4078c77f98 docs: add OpenCL (#11697) 2025-02-11 15:04:13 -07:00
Sheldon Robinson 90e4dba461 Fix #11802: Compile bug - RegQueryValueExA changed to RegQueryValueEx (#11803)
* Fix #11802: Compile bug - RegQueryValueExA changed to RegQueryValueEx

* Fix #11802: PR #11803 - keep RegQueryValueExA, remove TEXT macro, description needs to be ANSI string
2025-02-11 16:55:45 +01:00
Daniel Bevenius a18f481f99 server : use common_token_to_piece instead of common_detokenize (#11740)
* server : use common_token_to_piece instead of common_detokenize

This commit replaces the call to common_detokenize with
common_token_to_piece in the populate_token_probs.

The motivation for this change is to avoid an issue where
common_detokenize would remove the word boundary character for tokens,
which caused a regression in the server generated token probabilities.

Resolves: https://github.com/ggerganov/llama.cpp/issues/11728

* squash! server : use common_token_to_piece instead of common_detokenize

Use common_token_to_piece for post_sampling_probs as well.
2025-02-11 14:06:45 +01:00
Johannes Gäßler b9ab0a4d0b CUDA: use arch list for compatibility check (#11775)
* CUDA: use arch list for feature availability check

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-02-11 00:17:22 +01:00
Maxim Evtush 7b891bdc86 fix: typos in documentation files (#11791)
* Update ggml.c

* Update arg.cpp

* Update speculative.h
2025-02-10 23:21:31 +01:00
jason_w 81732619fd docs: utilize the forward slash (/) as the path separator for Unix-like systems (#11770) 2025-02-10 23:17:48 +01:00
Xuan-Son Nguyen 507f9174fe server : (webui) introduce conversation branching + idb storage (#11792)
* server : (webui) introduce conversation branching + idb storage

* mark old conv as "migrated" instead deleting them

* improve migration

* add more comments

* more clarification
2025-02-10 21:23:17 +01:00
Wilken Gottwalt 19b392d58d llama-mmap: fix missing include (#11796)
Technically the fixed width types come only from iostream and
cstdint/stdint.h headers. memory and vector headers should not provide
these. In GCC 15 the headers are cleaned up and you require the proper
header cstdint.

src/llama-mmap.h:26:5: error: ‘uint32_t’ does not name a type
   26 |     uint32_t read_u32() const;
      |     ^~~~~~~~
2025-02-10 20:58:18 +02:00
Xuan-Son Nguyen 0893e0114e server : correct signal handler (#11795) 2025-02-10 18:03:28 +01:00
Olivier Chafik d7b31a9d84 sync: minja (https://github.com/google/minja/commit/a72057e5190de2c612d4598bb10b4bfd0f53011f) (#11774) 2025-02-10 09:34:09 +00:00
pascal-lc 9ac3457b39 Update README.md [no ci] (#11781)
typo: `\` -> `/`
Change the UNIX path separator to` \`.
2025-02-10 09:05:57 +01:00
Danny Milosavljevic c2a67efe38 vulkan: Make Vulkan optional at runtime (#11493). (#11494)
Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
2025-02-10 07:17:21 +01:00
Wagner Bruna b044a0fe3c vulkan: add environment variable GGML_VK_PREFER_HOST_MEMORY to avoid VRAM allocation (#11592) 2025-02-10 07:08:22 +01:00
Eric Curtin 19d3c8293b There's a better way of clearing lines (#11756)
Use the ANSI escape code for clearing a line.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
2025-02-09 10:34:49 +00:00
Jeff Bolz 98f6b0fd1e vulkan: account for lookup tables when checking shared memory size (#11502) 2025-02-09 08:43:51 +01:00
Xuan-Son Nguyen 55ac8c7791 server : (webui) revamp Settings dialog, add Pyodide interpreter (#11759)
* redo Settings modal UI

* add python code interpreter

* fix auto scroll

* build

* fix overflow for long output lines

* bring back sticky copy button

* adapt layout on mobile view

* fix multiple lines output and color scheme

* handle python exception

* better state management

* add webworker

* add headers

* format code

* speed up by loading pyodide on page load

* (small tweak) add small animation to make it feels like claude
2025-02-08 21:54:50 +01:00
Woof Dog e6e6583199 server : (webui) increase edit textarea size (#11763) 2025-02-08 20:09:55 +01:00
Georgi Gerganov aaa5505307 server : minor log updates (#11760)
ggml-ci
2025-02-08 18:08:43 +02:00
Georgi Gerganov bdcf8b6a56 cont : fix mmap flag print (#11699) 2025-02-08 16:49:38 +02:00
Karol Kontny 4d3465c5ae ggml: Fix data race in ggml threadpool (#11736)
After the barrier in last iteration is executed, still the loop termination
condition will be executed. However main thread can destroy the cgraph object
and its nodes already, then another thread will access it, but the thing is already gone.
Also trouble can happen when n_nodes == 0 or abort is called, but I'm not sure if the
prior situation is possible.

Last syncronization should be done after the loop to ensure the cgraph/cplan won't be
accessed after the main thread exits from the function.
2025-02-08 15:30:53 +01:00
Johannes Gäßler d80be897ac CUDA: fix min. version for movmatrix (#11751) 2025-02-08 10:46:07 +01:00
Nikolaos Pothitos 3ab410f55f readme : update front-end framework (#11753)
After the migration to React with #11688
2025-02-08 10:43:04 +01:00
Xuan-Son Nguyen 0cf867160c server : (webui) fix numeric settings being saved as string (#11739)
* server : (webui) fix numeric settings being saved as string

* add some more comments
2025-02-08 10:42:34 +01:00
Eric Curtin d2fe216fb2 Make logging more verbose (#11714)
Debugged an issue with a user who was on a read-only filesystem.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
2025-02-07 14:42:46 +00:00
Georgi Gerganov ed926d8833 llama : fix defrag logic (#11707)
* llama : fix defrag logic

ggml-ci

* cont : better logic

ggml-ci

* cont : clamp fragmentation to 0.0

ggml-ci
2025-02-07 16:05:34 +02:00
Christian Fillion 2d219b389e vocab : ignore invalid UTF-8 input in the BPE tokenizer (#11729)
Silently insert U+FFFD(s) (Unicode replacement character) instead until the
next valid codepoint can be found.

This fixes `llama_tokenize` throwing an exception across the C API boundary
or libllama's module boundary (the caller's runtime might be incompatible!)

Returing a proper error code might be desirable, however the signature
of `llama_tokenize` doesn't allow it as all return values already have
existing meaning.
2025-02-07 15:55:47 +02:00
magicse 333820d749 llama : fix progress dots (#11730)
* Update llama.cpp

For display progress dots in terminal.
Without this it didn't display dots progress during loading model from file.

* Update llama.cpp

removed trailing spaces
2025-02-07 15:48:47 +02:00
Jeff Bolz c026ba3c23 vulkan: print shared memory size (#11719) 2025-02-07 11:26:03 +01:00
Christian Fillion 7ee953a64a llama : add llama_sampler_init for safe usage of llama_sampler_free (#11727)
The C API in llama.h claims users can implement `llama_sampler_i` to
create custom `llama_sampler`. The sampler chain takes ownership and
calls `llama_sampler_free` on them. However, `llama_sampler_free` is
hard-coded to use `delete`. This is undefined behavior if the object
wasn't also allocated via `new` from libllama's C++ runtime. Callers
in C and C-compatible languages do not use C++'s `new` operator. C++
callers may not be sharing the same heap as libllama.
2025-02-07 11:33:27 +02:00
Akarshan Biswas ec3bc8270b SYCL: remove XMX info from print devices (#11712) 2025-02-07 09:27:53 +00:00
Daniel Bevenius b7552cfcbc common : add default embeddings presets (#11677)
* common : add default embeddings presets

This commit adds default embeddings presets for the following models:
- bge-small-en-v1.5
- e5-small-v2
- gte-small

These can be used with llama-embedding and llama-server.

For example, with llama-embedding:
```console
./build/bin/llama-embedding --embd-gte-small-default -p "Hello, how are you?"
```

And with llama-server:
```console
./build/bin/llama-server --embd-gte-small-default
```
And the embeddings endpoint can then be called with a POST request:
```console
curl --request POST \
    --url http://localhost:8080/embeddings \
    --header "Content-Type: application/json" \
    --data '{"input": "Hello, how are you?"}'
```

I'm not sure if these are the most common embedding models but hopefully
this can be a good starting point for discussion and further
improvements.

Refs: https://github.com/ggerganov/llama.cpp/issues/10932
2025-02-07 09:15:22 +01:00
61 changed files with 3297 additions and 1137 deletions
+1
View File
@@ -235,6 +235,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [HIP](docs/build.md#hip) | AMD GPU |
| [Vulkan](docs/build.md#vulkan) | GPU |
| [CANN](docs/build.md#cann) | Ascend NPU |
| [OpenCL](docs/backend/OPENCL.md) | Adreno GPU |
## Building the project
+43 -1
View File
@@ -674,7 +674,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
));
add_opt(common_arg(
{"--no-context-shift"},
string_format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
[](common_params & params) {
params.ctx_shift = false;
}
@@ -2324,5 +2324,47 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_TTS}));
add_opt(common_arg(
{"--embd-bge-small-en-default"},
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
params.hf_file = "bge-small-en-v1.5-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--embd-e5-small-en-default"},
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
params.hf_file = "e5-small-v2-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--embd-gte-small-default"},
string_format("use default gte-small model (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
params.hf_file = "gte-small-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}
+21 -7
View File
@@ -249,16 +249,30 @@ class chat_template {
inputs.add_generation_prompt = false;
full = apply(inputs);
}
if (full.find(prefix) != 0) {
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
auto eos_pos_last = full.rfind(eos_token_);
if (eos_pos_last == prefix.size() - eos_token_.size() ||
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
full = full.substr(0, eos_pos_last);
}
size_t common_prefix_length = 0;
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
if (prefix[i] != full[i]) {
break;
}
if (prefix[i] == '<') {
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
// but it removes thinking tags for past messages.
// The prefix and full strings diverge at <think> vs. <tool▁calls▁begin>, we avoid consuming the leading <.
continue;
}
common_prefix_length = i + 1;
}
if (full.find(prefix) != 0) {
auto example = full.substr(common_prefix_length);
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
} else {
tool_call_example_ = example;
}
tool_call_example_ = full.substr(prefix.size());
}
} catch (const std::exception & e) {
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@@ -363,7 +377,7 @@ class chat_template {
if (polyfill_tools) {
adjusted_messages = add_system(inputs.messages,
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
} else {
adjusted_messages = inputs.messages;
}
+6 -6
View File
@@ -424,13 +424,13 @@ bool set_process_priority(enum ggml_sched_priority prio);
//
#ifdef __GNUC__
#ifdef __MINGW32__
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# if defined(__MINGW32__) && !defined(__clang__)
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# else
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
# endif
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
#endif
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
+3 -3
View File
@@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
};
}
return new llama_sampler{
return llama_sampler_init(
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};
/* .ctx = */ ctx
);
}
#else
+1
View File
@@ -1,5 +1,6 @@
#include "log.h"
#include <chrono>
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
+2 -1
View File
@@ -2,6 +2,7 @@
#include "ggml.h" // for ggml_log_level
#define LOG_CLR_TO_EOL "\033[K\r"
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"
@@ -14,7 +15,7 @@
#ifndef __GNUC__
# define LOG_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__)
#elif defined(__MINGW32__) && !defined(__clang__)
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+24 -9
View File
@@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) {
return s.substr(start, end - start + 1);
}
static std::string capitalize(const std::string & s) {
if (s.empty()) return s;
auto result = s;
result[0] = std::toupper(result[0]);
return result;
}
static std::string html_escape(const std::string & s) {
std::string result;
result.reserve(s.size());
@@ -1462,6 +1469,9 @@ public:
if (method->get_name() == "strip") {
vargs.expectArgs("strip method", {0, 0}, {0, 0});
return Value(strip(str));
} else if (method->get_name() == "capitalize") {
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
return Value(capitalize(str));
} else if (method->get_name() == "endswith") {
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
auto suffix = vargs.args[0].get<std::string>();
@@ -1792,7 +1802,7 @@ private:
auto left = parseStringConcat();
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
static std::regex not_tok(R"(not\b)");
std::string op_str;
while (!(op_str = consumeToken(compare_tok)).empty()) {
@@ -2171,7 +2181,7 @@ private:
using TemplateTokenIterator = TemplateTokenVector::const_iterator;
std::vector<std::string> parseVarNames() {
static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
std::vector<std::string> group;
if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
@@ -2194,13 +2204,13 @@ private:
}
TemplateTokenVector tokenize() {
static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
static std::regex block_close_regex(R"(\s*([-~])?%\})");
TemplateTokenVector tokens;
std::vector<std::string> group;
@@ -2284,7 +2294,7 @@ private:
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
} else if (keyword == "set") {
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
std::string ns;
std::vector<std::string> var_names;
@@ -2336,6 +2346,11 @@ private:
throw std::runtime_error("Unexpected block: " + keyword);
}
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
if (!match.position()) {
if (match[0] != "{#")
throw std::runtime_error("Internal error: Expected a comment");
throw std::runtime_error("Missing end of comment tag");
}
auto text_end = it + match.position();
text = std::string(it, text_end);
it = text_end;
@@ -2400,7 +2415,7 @@ private:
auto text = text_token->text;
if (post_space == SpaceHandling::Strip) {
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
static std::regex trailing_space_regex(R"(\s+$)");
text = std::regex_replace(text, trailing_space_regex, "");
} else if (options.lstrip_blocks && it != end) {
auto i = text.size();
@@ -2410,7 +2425,7 @@ private:
}
}
if (pre_space == SpaceHandling::Strip) {
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
static std::regex leading_space_regex(R"(^\s+)");
text = std::regex_replace(text, leading_space_regex, "");
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
if (text.length() > 0 && text[0] == '\n') {
+1 -1
View File
@@ -9,7 +9,7 @@ struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
float p_min = 0.9f; // min probabiliy required to accept a token in the draft
float p_min = 0.9f; // min probability required to accept a token in the draft
};
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+205
View File
@@ -0,0 +1,205 @@
# llama.cpp for OpenCL
- [Background](#background)
- [OS](#os)
- [Hardware](#hardware)
- [DataType Supports](#datatype-supports)
- [Model Preparation](#model-preparation)
- [CMake Options](#cmake-options)
- [Android](#android)
- [Windows 11 Arm64](#windows-11-arm64)
- [Known Issue](#known-issues)
- [TODO](#todo)
## Background
OpenCL (Open Computing Language) is an open, royalty-free standard for cross-platform, parallel programming of diverse accelerators found in supercomputers, cloud servers, personal computers, mobile devices and embedded platforms. OpenCL specifies a programming language (based on C99) for programming these devices and application programming interfaces (APIs) to control the platform and execute programs on the compute devices. Similar to CUDA, OpenCL has been widely used to program GPUs and is supported by most GPU vendors.
### Llama.cpp + OpenCL
The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adreno GPU** firstly via OpenCL. Thanks to the portabilty of OpenCL, the OpenCL backend can also run on certain Intel GPUs although the performance is not optimal.
## OS
| OS | Status | Verified |
|---------|---------|------------------------------------------------|
| Android | Support | Snapdragon 8 Gen 3, Snapdragon 8 Elite |
| Windows | Support | Windows 11 Arm64 with Snapdragon X Elite |
| Linux | Support | Ubuntu 22.04 WSL2 with Intel 12700H |
## Hardware
### Adreno GPU
**Verified devices**
| Adreno GPU | Status |
|:------------------------------------:|:-------:|
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
| Adreno 830 (Snapdragon 8 Elite) | Support |
| Adreno X85 (Snapdragon X Elite) | Support |
## DataType Supports
| DataType | Status |
|:----------------------:|:--------------------------:|
| Q4_0 | Support |
| Q6_K | Support, but not optimized |
## Model Preparation
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration.
Currently we support `Q4_0` quantization and have optimize for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize`. For example,
```sh
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
```
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
## CMake Options
The OpenCL backend has the following CMake options that control the behavior of the backend.
| CMake options | Default value | Description |
|:---------------------------------:|:--------------:|:------------------------------------------|
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
## Android
Ubuntu 22.04 is used for targeting Android. Make sure the following tools are accessible from command line,
* Git
* CMake 3.29
* Ninja
* Python3
### I. Setup Environment
1. **Install NDK**
```sh
cd ~
wget https://dl.google.com/android/repository/commandlinetools-linux-8512546_latest.zip && \
unzip commandlinetools-linux-8512546_latest.zip && \
mkdir -p ~/android-sdk/cmdline-tools && \
mv cmdline-tools latest && \
mv latest ~/android-sdk/cmdline-tools/ && \
rm -rf commandlinetools-linux-8512546_latest.zip
yes | ~/android-sdk/cmdline-tools/latest/bin/sdkmanager "ndk;26.3.11579264"
```
2. **Install OpenCL Headers and Library**
```sh
mkdir -p ~/dev/llm
cd ~/dev/llm
git clone https://github.com/KhronosGroup/OpenCL-Headers && \
cd OpenCL-Headers && \
cp -r CL ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
cd ~/dev/llm
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && \
cd OpenCL-ICD-Loader && \
mkdir build_ndk26 && cd build_ndk26 && \
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \
-DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
-DOPENCL_ICD_LOADER_HEADERS_DIR=$HOME/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \
-DANDROID_ABI=arm64-v8a \
-DANDROID_PLATFORM=24 \
-DANDROID_STL=c++_shared && \
ninja && \
cp libOpenCL.so ~/android-sdk/ndk/26.3.11579264/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
```
### II. Build llama.cpp
```sh
cd ~/dev/llm
git clone https://github.com/ggerganov/llama.cpp && \
cd llama.cpp && \
mkdir build-android && cd build-android
cmake .. -G Ninja \
-DCMAKE_TOOLCHAIN_FILE=$HOME/android-sdk/ndk/26.3.11579264/build/cmake/android.toolchain.cmake \
-DANDROID_ABI=arm64-v8a \
-DANDROID_PLATFORM=android-28 \
-DBUILD_SHARED_LIBS=OFF \
-DGGML_OPENCL=ON
ninja
```
## Windows 11 Arm64
A Snapdragon X Elite device with Windows 11 Arm64 is used. Make sure the following tools are accessible from command line,
* Git
* CMake 3.29
* Clang 19
* Ninja
* Visual Studio 2022
Powershell is used for the following instructions.
### I. Setup Environment
1. **Install OpenCL Headers and Library**
```powershell
mkdir -p ~/dev/llm
cd ~/dev/llm
git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers
mkdir build && cd build
cmake .. -G Ninja `
-DBUILD_TESTING=OFF `
-DOPENCL_HEADERS_BUILD_TESTING=OFF `
-DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF `
-DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
cmake --build . --target install
cd ~/dev/llm
git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader
mkdir build && cd build
cmake .. -G Ninja `
-DCMAKE_BUILD_TYPE=Release `
-DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
-DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
cmake --build . --target install
```
### II. Build llama.cpp
```powershell
mkdir -p ~/dev/llm
cd ~/dev/llm
git clone https://github.com/ggerganov/llama.cpp && cd llama.cpp
mkdir build && cd build
cmake .. -G Ninja `
-DCMAKE_TOOLCHAIN_FILE="$HOME/dev/llm/llama.cpp/cmake/arm64-windows-llvm.cmake" `
-DCMAKE_BUILD_TYPE=Release `
-DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" `
-DBUILD_SHARED_LIBS=OFF `
-DGGML_OPENCL=ON
ninja
```
## Known Issues
- Qwen2.5 0.5B model produces gibberish output with Adreno kernels.
## TODO
- Fix Qwen2.5 0.5B
- Optimization for Q6_K
- Support and optimization for Q4_K
+1
View File
@@ -3,6 +3,7 @@
#include "log.h"
#include "llama.h"
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstring>
+1 -1
View File
@@ -37,7 +37,7 @@ Once downloaded, place your model in the models folder in llama.cpp.
##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
```bash
./llama-cli -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
```
### Windows:
+1
View File
@@ -3,6 +3,7 @@
#include "log.h"
#include "llama.h"
#include <chrono>
#include <algorithm>
#include <array>
#include <atomic>
+5 -12
View File
@@ -346,7 +346,7 @@ class HttpClient {
if (!output_file.empty()) {
output_file_partial = output_file + ".partial";
if (!out.open(output_file_partial, "ab")) {
printe("Failed to open file\n");
printe("Failed to open file for writing\n");
return 1;
}
@@ -535,8 +535,7 @@ class HttpClient {
static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
const std::string & progress_suffix) {
printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
progress_suffix.c_str());
printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str());
}
// Function to write data to a file
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
@@ -797,16 +796,13 @@ class LlamaData {
llama_model_ptr initialize_model(Opt & opt) {
ggml_backend_load_all();
resolve_model(opt.model_);
printe(
"\r%*s"
"\rLoading model",
get_terminal_width(), " ");
printe("\r" LOG_CLR_TO_EOL "Loading model");
llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params));
if (!model) {
printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
}
printe("\r%*s\r", static_cast<int>(sizeof("Loading model")), " ");
printe("\r" LOG_CLR_TO_EOL);
return model;
}
@@ -969,10 +965,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
static int read_user_input(std::string & user_input) {
static const char * prompt_prefix = "> ";
#ifdef WIN32
printf(
"\r%*s"
"\r" LOG_COL_DEFAULT "%s",
get_terminal_width(), " ", prompt_prefix);
printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
std::getline(std::cin, user_input);
if (std::cin.eof()) {
+1 -1
View File
@@ -220,7 +220,7 @@ services:
The project includes a web-based user interface that enables interaction with the model through the `/chat/completions` endpoint.
The web UI is developed using:
- `vue` framework for frontend development
- `react` framework for frontend development
- `tailwindcss` and `daisyui` for styling
- `vite` for build tooling
Binary file not shown.
+39 -31
View File
@@ -334,24 +334,24 @@ struct server_task {
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str());
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
params.sampling.grammar = json_schema_to_grammar(schema);
LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
}
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
LOG_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
} else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
}
@@ -367,12 +367,12 @@ struct server_task {
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
@@ -381,11 +381,11 @@ struct server_task {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Preserved token: %d\n", ids[0]);
SRV_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
}
}
}
@@ -717,7 +717,7 @@ struct server_task_result_cmpl_final : server_task_result {
std::string finish_reason = "length";
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
LOG_DBG("Parsing chat message: %s\n", content.c_str());
SRV_DBG("Parsing chat message: %s\n", content.c_str());
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else {
@@ -1600,6 +1600,10 @@ struct server_queue {
while (true) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
if (queue_tasks.empty()) {
lock.unlock();
break;
@@ -1620,11 +1624,11 @@ struct server_queue {
QUE_DBG("%s", "waiting for new tasks\n");
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
if (queue_tasks.empty()) {
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
condition_tasks.wait(lock, [&]{
return (!queue_tasks.empty() || !running);
});
@@ -1885,7 +1889,7 @@ struct server_context {
}
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_from_model(model, "chatml");
} else {
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
@@ -2275,7 +2279,7 @@ struct server_context {
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
result.probs.push_back({
cur_p->data[i].id,
common_detokenize(ctx, {cur_p->data[i].id}, special),
common_token_to_piece(ctx, cur_p->data[i].id, special),
cur_p->data[i].p
});
}
@@ -2297,7 +2301,7 @@ struct server_context {
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
result.probs.push_back({
cur[i].id,
common_detokenize(ctx, {cur[i].id}, special),
common_token_to_piece(ctx, cur[i].id, special),
cur[i].p
});
}
@@ -3355,10 +3359,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
LOG_DBG("request: %s\n", req.body.c_str());
LOG_DBG("response: %s\n", res.body.c_str());
SRV_DBG("request: %s\n", req.body.c_str());
SRV_DBG("response: %s\n", res.body.c_str());
}
std::function<void(int)> shutdown_handler;
@@ -3860,7 +3864,9 @@ int main(int argc, char ** argv) {
try {
const auto & prompt = data.at("prompt");
LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -4376,6 +4382,9 @@ int main(int argc, char ** argv) {
res.set_content("Error: gzip is not supported by this browser", "text/plain");
} else {
res.set_header("Content-Encoding", "gzip");
// COEP and COOP headers, required by pyodide (python interpreter)
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
}
return false;
@@ -4425,6 +4434,7 @@ int main(int argc, char ** argv) {
// clean up function, to be called before exit
auto clean_up = [&svr]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
svr->stop();
llama_backend_free();
};
@@ -4441,10 +4451,6 @@ int main(int argc, char ** argv) {
}
if (!was_bound) {
//LOG_ERROR("couldn't bind HTTP server socket", {
// {"hostname", params.hostname},
// {"port", params.port},
//});
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
clean_up();
return 1;
@@ -4461,7 +4467,7 @@ int main(int argc, char ** argv) {
if (!ctx_server.load_model(params)) {
clean_up();
t.join();
// t.join(); // FIXME: see below
LOG_ERR("%s: exiting due to model loading error\n", __func__);
return 1;
}
@@ -4485,13 +4491,10 @@ int main(int argc, char ** argv) {
});
shutdown_handler = [&](int) {
// this will unblock start_loop()
ctx_server.queue_tasks.terminate();
};
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
ctx_server.queue_tasks.start_loop();
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@@ -4506,8 +4509,13 @@ int main(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
// this call blocks the main thread until queue_tasks.terminate() is called
ctx_server.queue_tasks.start_loop();
clean_up();
t.join();
// t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this
return 0;
}
+17
View File
@@ -8,10 +8,12 @@
"name": "webui",
"version": "0.0.0",
"dependencies": {
"@heroicons/react": "^2.2.0",
"@sec-ant/readable-stream": "^0.6.0",
"@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20",
"daisyui": "^4.12.14",
"dexie": "^4.0.11",
"highlight.js": "^11.10.0",
"katex": "^0.16.15",
"postcss": "^8.4.49",
@@ -902,6 +904,15 @@
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
}
},
"node_modules/@heroicons/react": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.2.0.tgz",
"integrity": "sha512-LMcepvRaS9LYHJGsF0zzmgKCUim/X3N/DQKc4jepAXJ7l8QxJ1PmxJzqplF2Z3FE4PqBAIGyJAQ/w4B5dsqbtQ==",
"license": "MIT",
"peerDependencies": {
"react": ">= 16 || ^19.0.0-rc"
}
},
"node_modules/@humanfs/core": {
"version": "0.19.1",
"resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz",
@@ -2328,6 +2339,12 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/dexie": {
"version": "4.0.11",
"resolved": "https://registry.npmjs.org/dexie/-/dexie-4.0.11.tgz",
"integrity": "sha512-SOKO002EqlvBYYKQSew3iymBoN2EQ4BDw/3yprjh7kAfFzjBYkaMNa/pZvcA7HSWlcKSQb9XhPe3wKyQ0x4A8A==",
"license": "Apache-2.0"
},
"node_modules/didyoumean": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
+2
View File
@@ -11,10 +11,12 @@
"preview": "vite preview"
},
"dependencies": {
"@heroicons/react": "^2.2.0",
"@sec-ant/readable-stream": "^0.6.0",
"@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20",
"daisyui": "^4.12.14",
"dexie": "^4.0.11",
"highlight.js": "^11.10.0",
"katex": "^0.16.15",
"postcss": "^8.4.49",
+13 -2
View File
@@ -1,8 +1,9 @@
import { HashRouter, Outlet, Route, Routes } from 'react-router';
import Header from './components/Header';
import Sidebar from './components/Sidebar';
import { AppContextProvider } from './utils/app.context';
import { AppContextProvider, useAppContext } from './utils/app.context';
import ChatScreen from './components/ChatScreen';
import SettingDialog from './components/SettingDialog';
function App() {
return (
@@ -22,13 +23,23 @@ function App() {
}
function AppLayout() {
const { showSettings, setShowSettings } = useAppContext();
return (
<>
<Sidebar />
<div className="chat-screen drawer-content grow flex flex-col h-screen w-screen mx-auto px-4">
<div
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto"
id="main-scroll"
>
<Header />
<Outlet />
</div>
{
<SettingDialog
show={showSettings}
onClose={() => setShowSettings(false)}
/>
}
</>
);
}
+3
View File
@@ -10,6 +10,7 @@ export const BASE_URL = new URL('.', document.baseURI).href
export const CONFIG_DEFAULT = {
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
apiKey: '',
systemMessage: 'You are a helpful assistant.',
showTokensPerSecond: false,
@@ -36,6 +37,8 @@ export const CONFIG_DEFAULT = {
dry_penalty_last_n: -1,
max_tokens: -1,
custom: '', // custom json-stringified object
// experimental features
pyIntepreterEnabled: false,
};
export const CONFIG_INFO: Record<string, string> = {
apiKey: 'Set the API Key if you are using --api-key option for the server.',
@@ -0,0 +1,195 @@
import { useEffect, useState } from 'react';
import { useAppContext } from '../utils/app.context';
import { OpenInNewTab, XCloseButton } from '../utils/common';
import { CanvasType } from '../utils/types';
import { PlayIcon, StopIcon } from '@heroicons/react/24/outline';
import StorageUtils from '../utils/storage';
const canInterrupt = typeof SharedArrayBuffer === 'function';
// adapted from https://pyodide.org/en/stable/usage/webworker.html
const WORKER_CODE = `
importScripts("https://cdn.jsdelivr.net/pyodide/v0.27.2/full/pyodide.js");
let stdOutAndErr = [];
let pyodideReadyPromise = loadPyodide({
stdout: (data) => stdOutAndErr.push(data),
stderr: (data) => stdOutAndErr.push(data),
});
let alreadySetBuff = false;
self.onmessage = async (event) => {
stdOutAndErr = [];
// make sure loading is done
const pyodide = await pyodideReadyPromise;
const { id, python, context, interruptBuffer } = event.data;
if (interruptBuffer && !alreadySetBuff) {
pyodide.setInterruptBuffer(interruptBuffer);
alreadySetBuff = true;
}
// Now load any packages we need, run the code, and send the result back.
await pyodide.loadPackagesFromImports(python);
// make a Python dictionary with the data from content
const dict = pyodide.globals.get("dict");
const globals = dict(Object.entries(context));
try {
self.postMessage({ id, running: true });
// Execute the python code in this context
const result = pyodide.runPython(python, { globals });
self.postMessage({ result, id, stdOutAndErr });
} catch (error) {
self.postMessage({ error: error.message, id });
}
interruptBuffer[0] = 0;
};
`;
let worker: Worker;
const interruptBuffer = canInterrupt
? new Uint8Array(new SharedArrayBuffer(1))
: null;
const startWorker = () => {
if (!worker) {
worker = new Worker(
URL.createObjectURL(new Blob([WORKER_CODE], { type: 'text/javascript' }))
);
}
};
if (StorageUtils.getConfig().pyIntepreterEnabled) {
startWorker();
}
const runCodeInWorker = (
pyCode: string,
callbackRunning: () => void
): {
donePromise: Promise<string>;
interrupt: () => void;
} => {
startWorker();
const id = Math.random() * 1e8;
const context = {};
if (interruptBuffer) {
interruptBuffer[0] = 0;
}
const donePromise = new Promise<string>((resolve) => {
worker.onmessage = (event) => {
const { error, stdOutAndErr, running } = event.data;
if (id !== event.data.id) return;
if (running) {
callbackRunning();
return;
} else if (error) {
resolve(error.toString());
} else {
resolve(stdOutAndErr.join('\n'));
}
};
worker.postMessage({ id, python: pyCode, context, interruptBuffer });
});
const interrupt = () => {
console.log('Interrupting...');
console.trace();
if (interruptBuffer) {
interruptBuffer[0] = 2;
}
};
return { donePromise, interrupt };
};
export default function CanvasPyInterpreter() {
const { canvasData, setCanvasData } = useAppContext();
const [code, setCode] = useState(canvasData?.content ?? ''); // copy to avoid direct mutation
const [running, setRunning] = useState(false);
const [output, setOutput] = useState('');
const [interruptFn, setInterruptFn] = useState<() => void>();
const [showStopBtn, setShowStopBtn] = useState(false);
const runCode = async (pycode: string) => {
interruptFn?.();
setRunning(true);
setOutput('Loading Pyodide...');
const { donePromise, interrupt } = runCodeInWorker(pycode, () => {
setOutput('Running...');
setShowStopBtn(canInterrupt);
});
setInterruptFn(() => interrupt);
const out = await donePromise;
setOutput(out);
setRunning(false);
setShowStopBtn(false);
};
// run code on mount
useEffect(() => {
setCode(canvasData?.content ?? '');
runCode(canvasData?.content ?? '');
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [canvasData?.content]);
if (canvasData?.type !== CanvasType.PY_INTERPRETER) {
return null;
}
return (
<div className="card bg-base-200 w-full h-full shadow-xl">
<div className="card-body">
<div className="flex justify-between items-center mb-4">
<span className="text-lg font-bold">Python Interpreter</span>
<XCloseButton
className="bg-base-100"
onClick={() => setCanvasData(null)}
/>
</div>
<div className="grid grid-rows-3 gap-4 h-full">
<textarea
className="textarea textarea-bordered w-full h-full font-mono"
value={code}
onChange={(e) => setCode(e.target.value)}
></textarea>
<div className="font-mono flex flex-col row-span-2">
<div className="flex items-center mb-2">
<button
className="btn btn-sm bg-base-100"
onClick={() => runCode(code)}
disabled={running}
>
<PlayIcon className="h-6 w-6" /> Run
</button>
{showStopBtn && (
<button
className="btn btn-sm bg-base-100 ml-2"
onClick={() => interruptFn?.()}
>
<StopIcon className="h-6 w-6" /> Stop
</button>
)}
<span className="grow text-right text-xs">
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/11762">
Report a bug
</OpenInNewTab>
</span>
</div>
<textarea
className="textarea textarea-bordered h-full dark-color"
value={output}
readOnly
></textarea>
</div>
</div>
</div>
</div>
);
}
@@ -3,6 +3,7 @@ import { useAppContext } from '../utils/app.context';
import { Message, PendingMessage } from '../utils/types';
import { classNames } from '../utils/misc';
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/24/outline';
interface SplitMessage {
content: PendingMessage['content'];
@@ -12,17 +13,24 @@ interface SplitMessage {
export default function ChatMessage({
msg,
siblingLeafNodeIds,
siblingCurrIdx,
id,
scrollToBottom,
onRegenerateMessage,
onEditMessage,
onChangeSibling,
isPending,
}: {
msg: Message | PendingMessage;
siblingLeafNodeIds: Message['id'][];
siblingCurrIdx: number;
id?: string;
scrollToBottom: (requiresNearBottom: boolean) => void;
onRegenerateMessage(msg: Message): void;
onEditMessage(msg: Message, content: string): void;
onChangeSibling(sibling: Message['id']): void;
isPending?: boolean;
}) {
const { viewingConversation, replaceMessageAndGenerate, config } =
useAppContext();
const { viewingChat, config } = useAppContext();
const [editingContent, setEditingContent] = useState<string | null>(null);
const timings = useMemo(
() =>
@@ -37,6 +45,8 @@ export default function ChatMessage({
: null,
[msg.timings]
);
const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
// for reasoning model, we split the message into content and thought
// TODO: implement this as remark/rehype plugin in the future
@@ -64,13 +74,7 @@ export default function ChatMessage({
return { content: actualContent, thought, isThinking };
}, [msg]);
if (!viewingConversation) return null;
const regenerate = async () => {
replaceMessageAndGenerate(viewingConversation.id, msg.id, undefined, () =>
scrollToBottom(true)
);
};
if (!viewingChat) return null;
return (
<div className="group" id={id}>
@@ -92,7 +96,7 @@ export default function ChatMessage({
<>
<textarea
dir="auto"
className="textarea textarea-bordered bg-base-100 text-base-content w-[calc(90vw-8em)] lg:w-96"
className="textarea textarea-bordered bg-base-100 text-base-content max-w-2xl w-[calc(90vw-8em)] h-24"
value={editingContent}
onChange={(e) => setEditingContent(e.target.value)}
></textarea>
@@ -105,13 +109,12 @@ export default function ChatMessage({
</button>
<button
className="btn mt-2"
onClick={() =>
replaceMessageAndGenerate(
viewingConversation.id,
msg.id,
editingContent
)
}
onClick={() => {
if (msg.content !== null) {
setEditingContent(null);
onEditMessage(msg as Message, editingContent);
}
}}
>
Submit
</button>
@@ -149,11 +152,17 @@ export default function ChatMessage({
)}
</summary>
<div className="collapse-content">
<MarkdownDisplay content={thought} />
<MarkdownDisplay
content={thought}
isGenerating={isPending}
/>
</div>
</details>
)}
<MarkdownDisplay content={content} />
<MarkdownDisplay
content={content}
isGenerating={isPending}
/>
</div>
</>
)}
@@ -190,10 +199,35 @@ export default function ChatMessage({
{msg.content !== null && (
<div
className={classNames({
'mx-4 mt-2 mb-2': true,
'text-right': msg.role === 'user',
'flex items-center gap-2 mx-4 mt-2 mb-2': true,
'flex-row-reverse': msg.role === 'user',
})}
>
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
<div className="flex gap-1 items-center opacity-60 text-sm">
<button
className={classNames({
'btn btn-sm btn-ghost p-1': true,
'opacity-20': !prevSibling,
})}
onClick={() => prevSibling && onChangeSibling(prevSibling)}
>
<ChevronLeftIcon className="h-4 w-4" />
</button>
<span>
{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
</span>
<button
className={classNames({
'btn btn-sm btn-ghost p-1': true,
'opacity-20': !nextSibling,
})}
onClick={() => nextSibling && onChangeSibling(nextSibling)}
>
<ChevronRightIcon className="h-4 w-4" />
</button>
</div>
)}
{/* user message */}
{msg.role === 'user' && (
<button
@@ -210,18 +244,22 @@ export default function ChatMessage({
{!isPending && (
<button
className="badge btn-mini show-on-hover mr-2"
onClick={regenerate}
onClick={() => {
if (msg.content !== null) {
onRegenerateMessage(msg as Message);
}
}}
disabled={msg.content === null}
>
🔄 Regenerate
</button>
)}
<CopyButton
className="badge btn-mini show-on-hover mr-2"
content={msg.content}
/>
</>
)}
<CopyButton
className="badge btn-mini show-on-hover mr-2"
content={msg.content}
/>
</div>
)}
</div>
@@ -1,123 +1,243 @@
import { useEffect, useRef, useState } from 'react';
import { useAppContext } from '../utils/app.context';
import StorageUtils from '../utils/storage';
import { useNavigate } from 'react-router';
import { useEffect, useMemo, useState } from 'react';
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
import ChatMessage from './ChatMessage';
import { PendingMessage } from '../utils/types';
import { CanvasType, Message, PendingMessage } from '../utils/types';
import { classNames, throttle } from '../utils/misc';
import CanvasPyInterpreter from './CanvasPyInterpreter';
import StorageUtils from '../utils/storage';
/**
* A message display is a message node with additional information for rendering.
* For example, siblings of the message node are stored as their last node (aka leaf node).
*/
export interface MessageDisplay {
msg: Message | PendingMessage;
siblingLeafNodeIds: Message['id'][];
siblingCurrIdx: number;
isPending?: boolean;
}
function getListMessageDisplay(
msgs: Readonly<Message[]>,
leafNodeId: Message['id']
): MessageDisplay[] {
const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
const res: MessageDisplay[] = [];
const nodeMap = new Map<Message['id'], Message>();
for (const msg of msgs) {
nodeMap.set(msg.id, msg);
}
// find leaf node from a message node
const findLeafNode = (msgId: Message['id']): Message['id'] => {
let currNode: Message | undefined = nodeMap.get(msgId);
while (currNode) {
if (currNode.children.length === 0) break;
currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
}
return currNode?.id ?? -1;
};
// traverse the current nodes
for (const msg of currNodes) {
const parentNode = nodeMap.get(msg.parent ?? -1);
if (!parentNode) continue;
const siblings = parentNode.children;
if (msg.type !== 'root') {
res.push({
msg,
siblingLeafNodeIds: siblings.map(findLeafNode),
siblingCurrIdx: siblings.indexOf(msg.id),
});
}
}
return res;
}
const scrollToBottom = throttle(
(requiresNearBottom: boolean, delay: number = 80) => {
const mainScrollElem = document.getElementById('main-scroll');
if (!mainScrollElem) return;
const spaceToBottom =
mainScrollElem.scrollHeight -
mainScrollElem.scrollTop -
mainScrollElem.clientHeight;
if (!requiresNearBottom || spaceToBottom < 50) {
setTimeout(
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
delay
);
}
},
80
);
export default function ChatScreen() {
const {
viewingConversation,
viewingChat,
sendMessage,
isGenerating,
stopGenerating,
pendingMessages,
canvasData,
replaceMessageAndGenerate,
} = useAppContext();
const [inputMsg, setInputMsg] = useState('');
const containerRef = useRef<HTMLDivElement>(null);
const navigate = useNavigate();
const currConvId = viewingConversation?.id ?? '';
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
// keep track of leaf node for rendering
const [currNodeId, setCurrNodeId] = useState<number>(-1);
const messages: MessageDisplay[] = useMemo(() => {
if (!viewingChat) return [];
else return getListMessageDisplay(viewingChat.messages, currNodeId);
}, [currNodeId, viewingChat]);
const scrollToBottom = (requiresNearBottom: boolean) => {
if (!containerRef.current) return;
const msgListElem = containerRef.current;
const spaceToBottom =
msgListElem.scrollHeight -
msgListElem.scrollTop -
msgListElem.clientHeight;
if (!requiresNearBottom || spaceToBottom < 50) {
setTimeout(
() => msgListElem.scrollTo({ top: msgListElem.scrollHeight }),
1
);
const currConvId = viewingChat?.conv.id ?? null;
const pendingMsg: PendingMessage | undefined =
pendingMessages[currConvId ?? ''];
useEffect(() => {
// reset to latest node when conversation changes
setCurrNodeId(-1);
// scroll to bottom when conversation changes
scrollToBottom(false, 1);
}, [currConvId]);
const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
if (currLeafNodeId) {
setCurrNodeId(currLeafNodeId);
}
scrollToBottom(true);
};
// scroll to bottom when conversation changes
useEffect(() => {
scrollToBottom(false);
}, [viewingConversation?.id]);
const sendNewMessage = async () => {
if (inputMsg.trim().length === 0 || isGenerating(currConvId)) return;
const convId = viewingConversation?.id ?? StorageUtils.getNewConvId();
if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
const lastInpMsg = inputMsg;
setInputMsg('');
if (!viewingConversation) {
// if user is creating a new conversation, redirect to the new conversation
navigate(`/chat/${convId}`);
}
scrollToBottom(false);
// auto scroll as message is being generated
const onChunk = () => scrollToBottom(true);
if (!(await sendMessage(convId, inputMsg, onChunk))) {
setCurrNodeId(-1);
// get the last message node
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) {
// restore the input message if failed
setInputMsg(lastInpMsg);
}
};
const handleEditMessage = async (msg: Message, content: string) => {
if (!viewingChat) return;
setCurrNodeId(msg.id);
scrollToBottom(false);
await replaceMessageAndGenerate(
viewingChat.conv.id,
msg.parent,
content,
onChunk
);
setCurrNodeId(-1);
scrollToBottom(false);
};
const handleRegenerateMessage = async (msg: Message) => {
if (!viewingChat) return;
setCurrNodeId(msg.parent);
scrollToBottom(false);
await replaceMessageAndGenerate(
viewingChat.conv.id,
msg.parent,
null,
onChunk
);
setCurrNodeId(-1);
scrollToBottom(false);
};
const hasCanvas = !!canvasData;
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
const pendingMsgDisplay: MessageDisplay[] =
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
? [
{
msg: pendingMsg,
siblingLeafNodeIds: [],
siblingCurrIdx: 0,
isPending: true,
},
]
: [];
return (
<>
{/* chat messages */}
<div
className={classNames({
'grid lg:gap-8 grow transition-[300ms]': true,
'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
'grid-cols-[1fr_0fr]': !hasCanvas,
})}
>
<div
id="messages-list"
className="flex flex-col grow overflow-y-auto"
ref={containerRef}
className={classNames({
'flex flex-col w-full max-w-[900px] mx-auto': true,
'hidden lg:flex': hasCanvas, // adapted for mobile
flex: !hasCanvas,
})}
>
<div className="mt-auto flex justify-center">
{/* placeholder to shift the message to the bottom */}
{viewingConversation ? '' : 'Send a message to start'}
{/* chat messages */}
<div id="messages-list" className="grow">
<div className="mt-auto flex justify-center">
{/* placeholder to shift the message to the bottom */}
{viewingChat ? '' : 'Send a message to start'}
</div>
{[...messages, ...pendingMsgDisplay].map((msg) => (
<ChatMessage
key={msg.msg.id}
msg={msg.msg}
siblingLeafNodeIds={msg.siblingLeafNodeIds}
siblingCurrIdx={msg.siblingCurrIdx}
onRegenerateMessage={handleRegenerateMessage}
onEditMessage={handleEditMessage}
onChangeSibling={setCurrNodeId}
/>
))}
</div>
{viewingConversation?.messages.map((msg) => (
<ChatMessage key={msg.id} msg={msg} scrollToBottom={scrollToBottom} />
))}
{pendingMsg && (
<ChatMessage
msg={pendingMsg}
scrollToBottom={scrollToBottom}
isPending
id="pending-msg"
/>
{/* chat input */}
<div className="flex flex-row items-center pt-8 pb-6 sticky bottom-0 bg-base-100">
<textarea
className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)"
value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter' && e.shiftKey) return;
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendNewMessage();
}
}}
id="msg-input"
dir="auto"
></textarea>
{isGenerating(currConvId ?? '') ? (
<button
className="btn btn-neutral ml-2"
onClick={() => stopGenerating(currConvId ?? '')}
>
Stop
</button>
) : (
<button
className="btn btn-primary ml-2"
onClick={sendNewMessage}
disabled={inputMsg.trim().length === 0}
>
Send
</button>
)}
</div>
</div>
<div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
{canvasData?.type === CanvasType.PY_INTERPRETER && (
<CanvasPyInterpreter />
)}
</div>
{/* chat input */}
<div className="flex flex-row items-center mt-8 mb-6">
<textarea
className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)"
value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter' && e.shiftKey) return;
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendNewMessage();
}
}}
id="msg-input"
dir="auto"
></textarea>
{isGenerating(currConvId) ? (
<button
className="btn btn-neutral ml-2"
onClick={() => stopGenerating(currConvId)}
>
Stop
</button>
) : (
<button
className="btn btn-primary ml-2"
onClick={sendNewMessage}
disabled={inputMsg.trim().length === 0}
>
Send
</button>
)}
</div>
</>
</div>
);
}
+44 -47
View File
@@ -5,12 +5,11 @@ import { classNames } from '../utils/misc';
import daisyuiThemes from 'daisyui/src/theming/themes';
import { THEMES } from '../Config';
import { useNavigate } from 'react-router';
import SettingDialog from './SettingDialog';
export default function Header() {
const navigate = useNavigate();
const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme());
const [showSettingDialog, setShowSettingDialog] = useState(false);
const { setShowSettings } = useAppContext();
const setTheme = (theme: string) => {
StorageUtils.setTheme(theme);
@@ -26,12 +25,12 @@ export default function Header() {
);
}, [selectedTheme]);
const { isGenerating, viewingConversation } = useAppContext();
const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? '');
const { isGenerating, viewingChat } = useAppContext();
const isCurrConvGenerating = isGenerating(viewingChat?.conv.id ?? '');
const removeConversation = () => {
if (isCurrConvGenerating || !viewingConversation) return;
const convId = viewingConversation.id;
if (isCurrConvGenerating || !viewingChat) return;
const convId = viewingChat?.conv.id;
if (window.confirm('Are you sure to delete this conversation?')) {
StorageUtils.remove(convId);
navigate('/');
@@ -39,9 +38,9 @@ export default function Header() {
};
const downloadConversation = () => {
if (isCurrConvGenerating || !viewingConversation) return;
const convId = viewingConversation.id;
const conversationJson = JSON.stringify(viewingConversation, null, 2);
if (isCurrConvGenerating || !viewingChat) return;
const convId = viewingChat?.conv.id;
const conversationJson = JSON.stringify(viewingChat, null, 2);
const blob = new Blob([conversationJson], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
@@ -54,7 +53,7 @@ export default function Header() {
};
return (
<div className="flex flex-row items-center mt-6 mb-6">
<div className="flex flex-row items-center pt-6 pb-6 sticky top-0 z-10 bg-base-100">
{/* open sidebar button */}
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
<svg
@@ -76,40 +75,43 @@ export default function Header() {
{/* action buttons (top right) */}
<div className="flex items-center">
<div v-if="messages.length > 0" className="dropdown dropdown-end">
{/* "..." button */}
<button
tabIndex={0}
role="button"
className="btn m-1"
disabled={isCurrConvGenerating}
>
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
fill="currentColor"
className="bi bi-three-dots-vertical"
viewBox="0 0 16 16"
{viewingChat && (
<div className="dropdown dropdown-end">
{/* "..." button */}
<button
tabIndex={0}
role="button"
className="btn m-1"
disabled={isCurrConvGenerating}
>
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0" />
</svg>
</button>
{/* dropdown menu */}
<ul
tabIndex={0}
className="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow"
>
<li onClick={downloadConversation}>
<a>Download</a>
</li>
<li className="text-error" onClick={removeConversation}>
<a>Delete</a>
</li>
</ul>
</div>
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
fill="currentColor"
className="bi bi-three-dots-vertical"
viewBox="0 0 16 16"
>
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0" />
</svg>
</button>
{/* dropdown menu */}
<ul
tabIndex={0}
className="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow"
>
<li onClick={downloadConversation}>
<a>Download</a>
</li>
<li className="text-error" onClick={removeConversation}>
<a>Delete</a>
</li>
</ul>
</div>
)}
<div className="tooltip tooltip-bottom" data-tip="Settings">
<button className="btn" onClick={() => setShowSettingDialog(true)}>
<button className="btn" onClick={() => setShowSettings(true)}>
{/* settings button */}
<svg
xmlns="http://www.w3.org/2000/svg"
@@ -172,11 +174,6 @@ export default function Header() {
</div>
</div>
</div>
<SettingDialog
show={showSettingDialog}
onClose={() => setShowSettingDialog(false)}
/>
</div>
);
}
@@ -9,8 +9,16 @@ import 'katex/dist/katex.min.css';
import { classNames, copyStr } from '../utils/misc';
import { ElementContent, Root } from 'hast';
import { visit } from 'unist-util-visit';
import { useAppContext } from '../utils/app.context';
import { CanvasType } from '../utils/types';
export default function MarkdownDisplay({ content }: { content: string }) {
export default function MarkdownDisplay({
content,
isGenerating,
}: {
content: string;
isGenerating?: boolean;
}) {
const preprocessedContent = useMemo(
() => preprocessLaTeX(content),
[content]
@@ -21,8 +29,13 @@ export default function MarkdownDisplay({ content }: { content: string }) {
rehypePlugins={[rehypeHightlight, rehypeKatex, rehypeCustomCopyButton]}
components={{
button: (props) => (
<CopyCodeButton {...props} origContent={preprocessedContent} />
<CodeBlockButtons
{...props}
isGenerating={isGenerating}
origContent={preprocessedContent}
/>
),
// note: do not use "pre", "p" or other basic html elements here, it will cause the node to re-render when the message is being generated (this should be a bug with react-markdown, not sure how to fix it)
}}
>
{preprocessedContent}
@@ -30,11 +43,12 @@ export default function MarkdownDisplay({ content }: { content: string }) {
);
}
const CopyCodeButton: React.ElementType<
const CodeBlockButtons: React.ElementType<
React.ClassAttributes<HTMLButtonElement> &
React.HTMLAttributes<HTMLButtonElement> &
ExtraProps & { origContent: string }
> = ({ node, origContent }) => {
ExtraProps & { origContent: string; isGenerating?: boolean }
> = ({ node, origContent, isGenerating }) => {
const { config } = useAppContext();
const startOffset = node?.position?.start.offset ?? 0;
const endOffset = node?.position?.end.offset ?? 0;
@@ -47,14 +61,33 @@ const CopyCodeButton: React.ElementType<
[origContent, startOffset, endOffset]
);
const codeLanguage = useMemo(
() =>
origContent
.substring(startOffset, startOffset + 10)
.match(/^```([^\n]+)\n/)?.[1] ?? '',
[origContent, startOffset]
);
const canRunCode =
!isGenerating &&
config.pyIntepreterEnabled &&
codeLanguage.startsWith('py');
return (
<div
className={classNames({
'text-right sticky top-4 mb-2 mr-2 h-0': true,
'text-right sticky top-[7em] mb-2 mr-2 h-0': true,
'display-none': !node?.position,
})}
>
<CopyButton className="badge btn-mini" content={copiedContent} />
{canRunCode && (
<RunPyCodeButton
className="badge btn-mini ml-2"
content={copiedContent}
/>
)}
</div>
);
};
@@ -81,6 +114,31 @@ export const CopyButton = ({
);
};
export const RunPyCodeButton = ({
content,
className,
}: {
content: string;
className?: string;
}) => {
const { setCanvasData } = useAppContext();
return (
<>
<button
className={className}
onClick={() =>
setCanvasData({
type: CanvasType.PY_INTERPRETER,
content,
})
}
>
Run
</button>
</>
);
};
/**
* This injects the "button" element before each "pre" element.
* The actual button will be replaced with a react component in the MarkdownDisplay.
@@ -94,9 +152,7 @@ function rehypeCustomCopyButton() {
// replace current node
preNode.properties.visited = 'true';
node.tagName = 'div';
node.properties = {
className: 'relative my-4',
};
node.properties = {};
// add node for button
const btnNode: ElementContent = {
type: 'element',
@@ -3,17 +3,27 @@ import { useAppContext } from '../utils/app.context';
import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config';
import { isDev } from '../Config';
import StorageUtils from '../utils/storage';
import { classNames, isBoolean, isNumeric, isString } from '../utils/misc';
import {
BeakerIcon,
ChatBubbleOvalLeftEllipsisIcon,
Cog6ToothIcon,
FunnelIcon,
HandRaisedIcon,
SquaresPlusIcon,
} from '@heroicons/react/24/outline';
import { OpenInNewTab } from '../utils/common';
type SettKey = keyof typeof CONFIG_DEFAULT;
const COMMON_SAMPLER_KEYS: SettKey[] = [
const BASIC_KEYS: SettKey[] = [
'temperature',
'top_k',
'top_p',
'min_p',
'max_tokens',
];
const OTHER_SAMPLER_KEYS: SettKey[] = [
const SAMPLER_KEYS: SettKey[] = [
'dynatemp_range',
'dynatemp_exponent',
'typical_p',
@@ -31,6 +41,223 @@ const PENALTY_KEYS: SettKey[] = [
'dry_penalty_last_n',
];
enum SettingInputType {
SHORT_INPUT,
LONG_INPUT,
CHECKBOX,
CUSTOM,
}
interface SettingFieldInput {
type: Exclude<SettingInputType, SettingInputType.CUSTOM>;
label: string | React.ReactElement;
help?: string | React.ReactElement;
key: SettKey;
}
interface SettingFieldCustom {
type: SettingInputType.CUSTOM;
key: SettKey;
component:
| string
| React.FC<{
value: string | boolean | number;
onChange: (value: string) => void;
}>;
}
interface SettingSection {
title: React.ReactElement;
fields: (SettingFieldInput | SettingFieldCustom)[];
}
const ICON_CLASSNAME = 'w-4 h-4 mr-1 inline';
const SETTING_SECTIONS: SettingSection[] = [
{
title: (
<>
<Cog6ToothIcon className={ICON_CLASSNAME} />
General
</>
),
fields: [
{
type: SettingInputType.SHORT_INPUT,
label: 'API Key',
key: 'apiKey',
},
{
type: SettingInputType.LONG_INPUT,
label: 'System Message (will be disabled if left empty)',
key: 'systemMessage',
},
...BASIC_KEYS.map(
(key) =>
({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
}) as SettingFieldInput
),
],
},
{
title: (
<>
<FunnelIcon className={ICON_CLASSNAME} />
Samplers
</>
),
fields: [
{
type: SettingInputType.SHORT_INPUT,
label: 'Samplers queue',
key: 'samplers',
},
...SAMPLER_KEYS.map(
(key) =>
({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
}) as SettingFieldInput
),
],
},
{
title: (
<>
<HandRaisedIcon className={ICON_CLASSNAME} />
Penalties
</>
),
fields: PENALTY_KEYS.map((key) => ({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
})),
},
{
title: (
<>
<ChatBubbleOvalLeftEllipsisIcon className={ICON_CLASSNAME} />
Reasoning
</>
),
fields: [
{
type: SettingInputType.CHECKBOX,
label: 'Expand though process by default for generating message',
key: 'showThoughtInProgress',
},
{
type: SettingInputType.CHECKBOX,
label:
'Exclude thought process when sending request to API (Recommended for DeepSeek-R1)',
key: 'excludeThoughtOnReq',
},
],
},
{
title: (
<>
<SquaresPlusIcon className={ICON_CLASSNAME} />
Advanced
</>
),
fields: [
{
type: SettingInputType.CUSTOM,
key: 'custom', // dummy key, won't be used
component: () => {
const debugImportDemoConv = async () => {
const res = await fetch('/demo-conversation.json');
const demoConv = await res.json();
StorageUtils.remove(demoConv.id);
for (const msg of demoConv.messages) {
StorageUtils.appendMsg(demoConv.id, msg);
}
};
return (
<button className="btn" onClick={debugImportDemoConv}>
(debug) Import demo conversation
</button>
);
},
},
{
type: SettingInputType.CHECKBOX,
label: 'Show tokens per second',
key: 'showTokensPerSecond',
},
{
type: SettingInputType.LONG_INPUT,
label: (
<>
Custom JSON config (For more info, refer to{' '}
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md">
server documentation
</OpenInNewTab>
)
</>
),
key: 'custom',
},
],
},
{
title: (
<>
<BeakerIcon className={ICON_CLASSNAME} />
Experimental
</>
),
fields: [
{
type: SettingInputType.CUSTOM,
key: 'custom', // dummy key, won't be used
component: () => (
<>
<p className="mb-8">
Experimental features are not guaranteed to work correctly.
<br />
<br />
If you encounter any problems, create a{' '}
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/new?template=019-bug-misc.yml">
Bug (misc.)
</OpenInNewTab>{' '}
report on Github. Please also specify <b>webui/experimental</b> on
the report title and include screenshots.
<br />
<br />
Some features may require packages downloaded from CDN, so they
need internet connection.
</p>
</>
),
},
{
type: SettingInputType.CHECKBOX,
label: (
<>
<b>Enable Python interpreter</b>
<br />
<small className="text-xs">
This feature uses{' '}
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
downloaded from CDN. To use this feature, ask the LLM to generate
python code inside a markdown code block. You will see a "Run"
button on the code block, near the "Copy" button.
</small>
</>
),
key: 'pyIntepreterEnabled',
},
],
},
];
export default function SettingDialog({
show,
onClose,
@@ -39,6 +266,7 @@ export default function SettingDialog({
onClose: () => void;
}) {
const { config, saveConfig } = useAppContext();
const [sectionIdx, setSectionIdx] = useState(0);
// clone the config object to prevent direct mutation
const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
@@ -52,206 +280,148 @@ export default function SettingDialog({
};
const handleSave = () => {
saveConfig(localConfig);
// copy the local config to prevent direct mutation
const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
JSON.stringify(localConfig)
);
// validate the config
for (const key in newConfig) {
const value = newConfig[key as SettKey];
const mustBeBoolean = isBoolean(CONFIG_DEFAULT[key as SettKey]);
const mustBeString = isString(CONFIG_DEFAULT[key as SettKey]);
const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
if (mustBeString) {
if (!isString(value)) {
alert(`Value for ${key} must be string`);
return;
}
} else if (mustBeNumeric) {
const trimedValue = value.toString().trim();
const numVal = Number(trimedValue);
if (isNaN(numVal) || !isNumeric(numVal) || trimedValue.length === 0) {
alert(`Value for ${key} must be numeric`);
return;
}
// force conversion to number
// @ts-expect-error this is safe
newConfig[key] = numVal;
} else if (mustBeBoolean) {
if (!isBoolean(value)) {
alert(`Value for ${key} must be boolean`);
return;
}
} else {
console.error(`Unknown default type for key ${key}`);
}
}
if (isDev) console.log('Saving config', newConfig);
saveConfig(newConfig);
onClose();
};
const debugImportDemoConv = async () => {
const res = await fetch('/demo-conversation.json');
const demoConv = await res.json();
StorageUtils.remove(demoConv.id);
for (const msg of demoConv.messages) {
StorageUtils.appendMsg(demoConv.id, msg);
}
onClose();
const onChange = (key: SettKey) => (value: string | boolean) => {
// note: we do not perform validation here, because we may get incomplete value as user is still typing it
setLocalConfig({ ...localConfig, [key]: value });
};
return (
<dialog className={`modal ${show ? 'modal-open' : ''}`}>
<div className="modal-box">
<dialog className={classNames({ modal: true, 'modal-open': show })}>
<div className="modal-box w-11/12 max-w-3xl">
<h3 className="text-lg font-bold mb-6">Settings</h3>
<div className="h-[calc(90vh-12rem)] overflow-y-auto">
<p className="opacity-40 mb-6">
Settings below are saved in browser's localStorage
</p>
<SettingsModalShortInput
configKey="apiKey"
configDefault={CONFIG_DEFAULT}
value={localConfig.apiKey}
onChange={(value) =>
setLocalConfig({ ...localConfig, apiKey: value })
}
/>
<label className="form-control mb-2">
<div className="label">
System Message (will be disabled if left empty)
</div>
<textarea
className="textarea textarea-bordered h-24"
placeholder={`Default: ${CONFIG_DEFAULT.systemMessage}`}
value={localConfig.systemMessage}
onChange={(e) =>
setLocalConfig({
...localConfig,
systemMessage: e.target.value,
})
}
/>
</label>
{COMMON_SAMPLER_KEYS.map((key) => (
<SettingsModalShortInput
key={key}
configKey={key}
configDefault={CONFIG_DEFAULT}
value={localConfig[key]}
onChange={(value) =>
setLocalConfig({ ...localConfig, [key]: value })
}
/>
))}
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Other sampler settings
</summary>
<div className="collapse-content">
<SettingsModalShortInput
label="Samplers queue"
configKey="samplers"
configDefault={CONFIG_DEFAULT}
value={localConfig.samplers}
onChange={(value) =>
setLocalConfig({ ...localConfig, samplers: value })
}
/>
{OTHER_SAMPLER_KEYS.map((key) => (
<SettingsModalShortInput
key={key}
configKey={key}
configDefault={CONFIG_DEFAULT}
value={localConfig[key]}
onChange={(value) =>
setLocalConfig({ ...localConfig, [key]: value })
}
/>
))}
</div>
</details>
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Penalties settings
</summary>
<div className="collapse-content">
{PENALTY_KEYS.map((key) => (
<SettingsModalShortInput
key={key}
configKey={key}
configDefault={CONFIG_DEFAULT}
value={localConfig[key]}
onChange={(value) =>
setLocalConfig({ ...localConfig, [key]: value })
}
/>
))}
</div>
</details>
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Reasoning models
</summary>
<div className="collapse-content">
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="checkbox"
checked={localConfig.showThoughtInProgress}
onChange={(e) =>
setLocalConfig({
...localConfig,
showThoughtInProgress: e.target.checked,
})
}
/>
<span className="ml-4">
Expand though process by default for generating message
</span>
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
{/* Left panel, showing sections - Desktop version */}
<div className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200">
{SETTING_SECTIONS.map((section, idx) => (
<div
key={idx}
className={classNames({
'btn btn-ghost justify-start font-normal w-44 mb-1': true,
'btn-active': sectionIdx === idx,
})}
onClick={() => setSectionIdx(idx)}
dir="auto"
>
{section.title}
</div>
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="checkbox"
checked={localConfig.excludeThoughtOnReq}
onChange={(e) =>
setLocalConfig({
...localConfig,
excludeThoughtOnReq: e.target.checked,
})
}
/>
<span className="ml-4">
Exclude thought process when sending request to API
(Recommended for DeepSeek-R1)
</span>
</div>
</div>
</details>
))}
</div>
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Advanced config
</summary>
<div className="collapse-content">
{/* this button only shows in dev mode, used to import a demo conversation to test message rendering */}
{isDev && (
<div className="flex flex-row items-center mb-2">
<button className="btn" onClick={debugImportDemoConv}>
(debug) Import demo conversation
</button>
</div>
)}
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="checkbox"
checked={localConfig.showTokensPerSecond}
onChange={(e) =>
setLocalConfig({
...localConfig,
showTokensPerSecond: e.target.checked,
})
}
/>
<span className="ml-4">Show tokens per second</span>
</div>
<label className="form-control mb-2">
<div className="label inline">
Custom JSON config (For more info, refer to{' '}
<a
className="underline"
href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md"
target="_blank"
rel="noopener noreferrer"
{/* Left panel, showing sections - Mobile version */}
<div className="md:hidden flex flex-row gap-2 mb-4">
<details className="dropdown">
<summary className="btn bt-sm w-full m-1">
{SETTING_SECTIONS[sectionIdx].title}
</summary>
<ul className="menu dropdown-content bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
{SETTING_SECTIONS.map((section, idx) => (
<div
key={idx}
className={classNames({
'btn btn-ghost justify-start font-normal': true,
'btn-active': sectionIdx === idx,
})}
onClick={() => setSectionIdx(idx)}
dir="auto"
>
server documentation
</a>
)
</div>
<textarea
className="textarea textarea-bordered h-24"
placeholder='Example: { "mirostat": 1, "min_p": 0.1 }'
value={localConfig.custom}
onChange={(e) =>
setLocalConfig({ ...localConfig, custom: e.target.value })
}
/>
</label>
</div>
</details>
{section.title}
</div>
))}
</ul>
</details>
</div>
{/* Right panel, showing setting fields */}
<div className="grow overflow-y-auto px-4">
{SETTING_SECTIONS[sectionIdx].fields.map((field, idx) => {
const key = `${sectionIdx}-${idx}`;
if (field.type === SettingInputType.SHORT_INPUT) {
return (
<SettingsModalShortInput
key={key}
configKey={field.key}
value={localConfig[field.key]}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.LONG_INPUT) {
return (
<SettingsModalLongInput
key={key}
configKey={field.key}
value={localConfig[field.key].toString()}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.CHECKBOX) {
return (
<SettingsModalCheckbox
key={key}
configKey={field.key}
value={!!localConfig[field.key]}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.CUSTOM) {
return (
<div key={key} className="mb-2">
{typeof field.component === 'string'
? field.component
: field.component({
value: localConfig[field.key],
onChange: onChange(field.key),
})}
</div>
);
}
})}
<p className="opacity-40 mb-6 text-sm mt-8">
Settings are saved in browser's localStorage
</p>
</div>
</div>
<div className="modal-action">
@@ -270,37 +440,97 @@ export default function SettingDialog({
);
}
function SettingsModalShortInput({
function SettingsModalLongInput({
configKey,
configDefault,
value,
onChange,
label,
}: {
configKey: SettKey;
configDefault: typeof CONFIG_DEFAULT;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
value: any;
value: string;
onChange: (value: string) => void;
label?: string;
}) {
return (
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
<div className="dropdown dropdown-hover">
<div tabIndex={0} role="button" className="font-bold">
{label || configKey}
</div>
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
{CONFIG_INFO[configKey] ?? '(no help message available)'}
</div>
</div>
<input
type="text"
className="grow"
placeholder={`Default: ${configDefault[configKey] || 'none'}`}
<label className="form-control mb-2">
<div className="label inline">{label || configKey}</div>
<textarea
className="textarea textarea-bordered h-24"
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
value={value}
onChange={(e) => onChange(e.target.value)}
/>
</label>
);
}
function SettingsModalShortInput({
configKey,
value,
onChange,
label,
}: {
configKey: SettKey;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
value: any;
onChange: (value: string) => void;
label?: string;
}) {
const helpMsg = CONFIG_INFO[configKey];
return (
<>
{/* on mobile, we simply show the help message here */}
{helpMsg && (
<div className="block md:hidden mb-1">
<b>{label || configKey}</b>
<br />
<p className="text-xs">{helpMsg}</p>
</div>
)}
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
<div className="dropdown dropdown-hover">
<div tabIndex={0} role="button" className="font-bold hidden md:block">
{label || configKey}
</div>
{helpMsg && (
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
{helpMsg}
</div>
)}
</div>
<input
type="text"
className="grow"
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
value={value}
onChange={(e) => onChange(e.target.value)}
/>
</label>
</>
);
}
function SettingsModalCheckbox({
configKey,
value,
onChange,
label,
}: {
configKey: SettKey;
value: boolean;
onChange: (value: boolean) => void;
label: string;
}) {
return (
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="toggle"
checked={value}
onChange={(e) => onChange(e.target.checked)}
/>
<span className="ml-4">{label || configKey}</span>
</div>
);
}
@@ -1,4 +1,4 @@
import { useEffect, useMemo, useState } from 'react';
import { useEffect, useState } from 'react';
import { classNames } from '../utils/misc';
import { Conversation } from '../utils/types';
import StorageUtils from '../utils/storage';
@@ -7,16 +7,17 @@ import { useNavigate, useParams } from 'react-router';
export default function Sidebar() {
const params = useParams();
const navigate = useNavigate();
const currConv = useMemo(
() => StorageUtils.getOneConversation(params.convId ?? ''),
[params.convId]
);
const [conversations, setConversations] = useState<Conversation[]>([]);
const [currConv, setCurrConv] = useState<Conversation | null>(null);
useEffect(() => {
const handleConversationChange = () => {
setConversations(StorageUtils.getAllConversations());
StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv);
}, [params.convId]);
useEffect(() => {
const handleConversationChange = async () => {
setConversations(await StorageUtils.getAllConversations());
};
StorageUtils.onConversationChanged(handleConversationChange);
handleConversationChange();
@@ -82,11 +83,11 @@ export default function Sidebar() {
onClick={() => navigate(`/chat/${conv.id}`)}
dir="auto"
>
<span className="truncate">{conv.messages[0].content}</span>
<span className="truncate">{conv.name}</span>
</div>
))}
<div className="text-center text-xs opacity-40 mt-auto mx-4">
Conversations are saved to browser's localStorage
Conversations are saved to browser's IndexedDB
</div>
</div>
</div>
+6
View File
@@ -45,6 +45,9 @@
/* Highlight.js */
[data-color-scheme='light'] {
@include meta.load-css('highlight.js/styles/stackoverflow-light');
.dark-color {
@apply bg-base-content text-base-100;
}
}
[data-color-scheme='dark'] {
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
@@ -52,6 +55,9 @@
[data-color-scheme='auto'] {
@media (prefers-color-scheme: light) {
@include meta.load-css('highlight.js/styles/stackoverflow-light');
.dark-color {
@apply bg-base-content text-base-100;
}
}
@media (prefers-color-scheme: dark) {
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
+130 -53
View File
@@ -1,5 +1,12 @@
import React, { createContext, useContext, useEffect, useState } from 'react';
import { APIMessage, Conversation, Message, PendingMessage } from './types';
import {
APIMessage,
CanvasData,
Conversation,
Message,
PendingMessage,
ViewingChat,
} from './types';
import StorageUtils from './storage';
import {
filterThoughtFromMsgs,
@@ -7,46 +14,65 @@ import {
getSSEStreamAsync,
} from './misc';
import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
import { matchPath, useLocation } from 'react-router';
import { matchPath, useLocation, useNavigate } from 'react-router';
interface AppContextValue {
viewingConversation: Conversation | null;
// conversations and messages
viewingChat: ViewingChat | null;
pendingMessages: Record<Conversation['id'], PendingMessage>;
isGenerating: (convId: string) => boolean;
sendMessage: (
convId: string,
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
onChunk?: CallbackGeneratedChunk
onChunk: CallbackGeneratedChunk
) => Promise<boolean>;
stopGenerating: (convId: string) => void;
replaceMessageAndGenerate: (
convId: string,
origMsgId: Message['id'],
content?: string,
onChunk?: CallbackGeneratedChunk
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
onChunk: CallbackGeneratedChunk
) => Promise<void>;
// canvas
canvasData: CanvasData | null;
setCanvasData: (data: CanvasData | null) => void;
// config
config: typeof CONFIG_DEFAULT;
saveConfig: (config: typeof CONFIG_DEFAULT) => void;
showSettings: boolean;
setShowSettings: (show: boolean) => void;
}
// for now, this callback is only used for scrolling to the bottom of the chat
type CallbackGeneratedChunk = () => void;
// this callback is used for scrolling to the bottom of the chat and switching to the last node
export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const AppContext = createContext<AppContextValue>({} as any);
const getViewingChat = async (convId: string): Promise<ViewingChat | null> => {
const conv = await StorageUtils.getOneConversation(convId);
if (!conv) return null;
return {
conv: conv,
// all messages from all branches, not filtered by last node
messages: await StorageUtils.getMessages(convId),
};
};
export const AppContextProvider = ({
children,
}: {
children: React.ReactElement;
}) => {
const { pathname } = useLocation();
const navigate = useNavigate();
const params = matchPath('/chat/:convId', pathname);
const convId = params?.params?.convId;
const [viewingConversation, setViewingConversation] =
useState<Conversation | null>(null);
const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null);
const [pendingMessages, setPendingMessages] = useState<
Record<Conversation['id'], PendingMessage>
>({});
@@ -54,14 +80,19 @@ export const AppContextProvider = ({
Record<Conversation['id'], AbortController>
>({});
const [config, setConfig] = useState(StorageUtils.getConfig());
const [canvasData, setCanvasData] = useState<CanvasData | null>(null);
const [showSettings, setShowSettings] = useState(false);
// handle change when the convId from URL is changed
useEffect(() => {
const handleConversationChange = (changedConvId: string) => {
// also reset the canvas data
setCanvasData(null);
const handleConversationChange = async (changedConvId: string) => {
if (changedConvId !== convId) return;
setViewingConversation(StorageUtils.getOneConversation(convId));
setViewingChat(await getViewingChat(changedConvId));
};
StorageUtils.onConversationChanged(handleConversationChange);
setViewingConversation(StorageUtils.getOneConversation(convId ?? ''));
getViewingChat(convId ?? '').then(setViewingChat);
return () => {
StorageUtils.offConversationChanged(handleConversationChange);
};
@@ -99,23 +130,39 @@ export const AppContextProvider = ({
const generateMessage = async (
convId: string,
onChunk?: CallbackGeneratedChunk
leafNodeId: Message['id'],
onChunk: CallbackGeneratedChunk
) => {
if (isGenerating(convId)) return;
const config = StorageUtils.getConfig();
const currConversation = StorageUtils.getOneConversation(convId);
const currConversation = await StorageUtils.getOneConversation(convId);
if (!currConversation) {
throw new Error('Current conversation is not found');
}
const currMessages = StorageUtils.filterByLeafNodeId(
await StorageUtils.getMessages(convId),
leafNodeId,
false
);
const abortController = new AbortController();
setAbort(convId, abortController);
if (!currMessages) {
throw new Error('Current messages are not found');
}
const pendingId = Date.now() + 1;
let pendingMsg: PendingMessage = {
id: Date.now() + 1,
id: pendingId,
convId,
type: 'text',
timestamp: pendingId,
role: 'assistant',
content: null,
parent: leafNodeId,
children: [],
};
setPending(convId, pendingMsg);
@@ -125,7 +172,7 @@ export const AppContextProvider = ({
...(config.systemMessage.length === 0
? []
: [{ role: 'system', content: config.systemMessage } as APIMessage]),
...normalizeMsgsForAPI(currConversation?.messages ?? []),
...normalizeMsgsForAPI(currMessages),
];
if (config.excludeThoughtOnReq) {
messages = filterThoughtFromMsgs(messages);
@@ -186,8 +233,7 @@ export const AppContextProvider = ({
const lastContent = pendingMsg.content || '';
if (addedContent) {
pendingMsg = {
id: pendingMsg.id,
role: 'assistant',
...pendingMsg,
content: lastContent + addedContent,
};
}
@@ -202,7 +248,7 @@ export const AppContextProvider = ({
};
}
setPending(convId, pendingMsg);
onChunk?.();
onChunk(); // don't need to switch node for pending message
}
} catch (err) {
setPending(convId, null);
@@ -217,37 +263,53 @@ export const AppContextProvider = ({
}
}
if (pendingMsg.content) {
StorageUtils.appendMsg(currConversation.id, {
id: pendingMsg.id,
content: pendingMsg.content,
role: pendingMsg.role,
timings: pendingMsg.timings,
});
if (pendingMsg.content !== null) {
await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
}
setPending(convId, null);
onChunk?.(); // trigger scroll to bottom
onChunk(pendingId); // trigger scroll to bottom and switch to the last node
};
const sendMessage = async (
convId: string,
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
onChunk?: CallbackGeneratedChunk
onChunk: CallbackGeneratedChunk
): Promise<boolean> => {
if (isGenerating(convId) || content.trim().length === 0) return false;
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
StorageUtils.appendMsg(convId, {
id: Date.now(),
role: 'user',
content,
});
if (convId === null || convId.length === 0 || leafNodeId === null) {
const conv = await StorageUtils.createConversation(
content.substring(0, 256)
);
convId = conv.id;
leafNodeId = conv.currNode;
// if user is creating a new conversation, redirect to the new conversation
navigate(`/chat/${convId}`);
}
const now = Date.now();
const currMsgId = now;
StorageUtils.appendMsg(
{
id: currMsgId,
timestamp: now,
type: 'text',
convId,
role: 'user',
content,
parent: leafNodeId,
children: [],
},
leafNodeId
);
onChunk(currMsgId);
try {
await generateMessage(convId, onChunk);
await generateMessage(convId, currMsgId, onChunk);
return true;
} catch (_) {
// rollback
StorageUtils.popMsg(convId);
// TODO: rollback
}
return false;
};
@@ -260,22 +322,33 @@ export const AppContextProvider = ({
// if content is undefined, we remove last assistant message
const replaceMessageAndGenerate = async (
convId: string,
origMsgId: Message['id'],
content?: string,
onChunk?: CallbackGeneratedChunk
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
onChunk: CallbackGeneratedChunk
) => {
if (isGenerating(convId)) return;
StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId);
if (content) {
StorageUtils.appendMsg(convId, {
id: Date.now(),
role: 'user',
content,
});
if (content !== null) {
const now = Date.now();
const currMsgId = now;
StorageUtils.appendMsg(
{
id: currMsgId,
timestamp: now,
type: 'text',
convId,
role: 'user',
content,
parent: parentNodeId,
children: [],
},
parentNodeId
);
parentNodeId = currMsgId;
}
onChunk(parentNodeId);
await generateMessage(convId, onChunk);
await generateMessage(convId, parentNodeId, onChunk);
};
const saveConfig = (config: typeof CONFIG_DEFAULT) => {
@@ -287,13 +360,17 @@ export const AppContextProvider = ({
<AppContext.Provider
value={{
isGenerating,
viewingConversation,
viewingChat,
pendingMessages,
sendMessage,
stopGenerating,
replaceMessageAndGenerate,
canvasData,
setCanvasData,
config,
saveConfig,
showSettings,
setShowSettings,
}}
>
{children}
@@ -0,0 +1,38 @@
export const XCloseButton: React.ElementType<
React.ClassAttributes<HTMLButtonElement> &
React.HTMLAttributes<HTMLButtonElement>
> = ({ className, ...props }) => (
<button className={`btn btn-square btn-sm ${className ?? ''}`} {...props}>
<svg
xmlns="http://www.w3.org/2000/svg"
className="h-6 w-6"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
);
export const OpenInNewTab = ({
href,
children,
}: {
href: string;
children: string;
}) => (
<a
className="underline"
href={href}
target="_blank"
rel="noopener noreferrer"
>
{children}
</a>
);
+25 -3
View File
@@ -4,7 +4,6 @@ import { APIMessage, Message } from './types';
// ponyfill for missing ReadableStream asyncIterator on Safari
import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
import { isDev } from '../Config';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const isString = (x: any) => !!x.toLowerCase;
@@ -23,7 +22,7 @@ export async function* getSSEStreamAsync(fetchResponse: Response) {
.pipeThrough(new TextLineStream());
// @ts-expect-error asyncIterator complains about type, but it should work
for await (const line of asyncIterator(lines)) {
if (isDev) console.log({ line });
//if (isDev) console.log({ line });
if (line.startsWith('data:') && !line.endsWith('[DONE]')) {
const data = JSON.parse(line.slice(5));
yield data;
@@ -55,7 +54,7 @@ export const copyStr = (textToCopy: string) => {
/**
* filter out redundant fields upon sending to API
*/
export function normalizeMsgsForAPI(messages: Message[]) {
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
return messages.map((msg) => {
return {
role: msg.role,
@@ -85,3 +84,26 @@ export function classNames(classes: Record<string, boolean>): string {
.map(([key, _]) => key)
.join(' ');
}
export const delay = (ms: number) =>
new Promise((resolve) => setTimeout(resolve, ms));
export const throttle = <T extends unknown[]>(
callback: (...args: T) => void,
delay: number
) => {
let isWaiting = false;
return (...args: T) => {
if (isWaiting) {
return;
}
callback(...args);
isWaiting = true;
setTimeout(() => {
isWaiting = false;
}, delay);
};
};
+204 -58
View File
@@ -2,7 +2,8 @@
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
import { CONFIG_DEFAULT } from '../Config';
import { Conversation, Message } from './types';
import { Conversation, Message, TimingReport } from './types';
import Dexie, { Table } from 'dexie';
const event = new EventTarget();
@@ -17,85 +18,154 @@ const dispatchConversationChange = (convId: string) => {
);
};
const db = new Dexie('LlamacppWebui') as Dexie & {
conversations: Table<Conversation>;
messages: Table<Message>;
};
// https://dexie.org/docs/Version/Version.stores()
db.version(1).stores({
// Unlike SQL, you dont need to specify all properties but only the one you wish to index.
conversations: '&id, lastModified',
messages: '&id, convId, [convId+id], timestamp',
});
// convId is a string prefixed with 'conv-'
const StorageUtils = {
/**
* manage conversations
*/
getAllConversations(): Conversation[] {
const res = [];
for (const key in localStorage) {
if (key.startsWith('conv-')) {
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
}
}
res.sort((a, b) => b.lastModified - a.lastModified);
return res;
async getAllConversations(): Promise<Conversation[]> {
await migrationLStoIDB().catch(console.error); // noop if already migrated
return (await db.conversations.toArray()).sort(
(a, b) => b.lastModified - a.lastModified
);
},
/**
* can return null if convId does not exist
*/
getOneConversation(convId: string): Conversation | null {
return JSON.parse(localStorage.getItem(convId) || 'null');
async getOneConversation(convId: string): Promise<Conversation | null> {
return (await db.conversations.where('id').equals(convId).first()) ?? null;
},
/**
* if convId does not exist, create one
* get all message nodes in a conversation
*/
appendMsg(convId: string, msg: Message): void {
if (msg.content === null) return;
const conv = StorageUtils.getOneConversation(convId) || {
id: convId,
lastModified: Date.now(),
messages: [],
async getMessages(convId: string): Promise<Message[]> {
return await db.messages.where({ convId }).toArray();
},
/**
* use in conjunction with getMessages to filter messages by leafNodeId
* includeRoot: whether to include the root node in the result
* if node with leafNodeId does not exist, return the path with the latest timestamp
*/
filterByLeafNodeId(
msgs: Readonly<Message[]>,
leafNodeId: Message['id'],
includeRoot: boolean
): Readonly<Message[]> {
const res: Message[] = [];
const nodeMap = new Map<Message['id'], Message>();
for (const msg of msgs) {
nodeMap.set(msg.id, msg);
}
let startNode: Message | undefined = nodeMap.get(leafNodeId);
if (!startNode) {
// if not found, we return the path with the latest timestamp
let latestTime = -1;
for (const msg of msgs) {
if (msg.timestamp > latestTime) {
startNode = msg;
latestTime = msg.timestamp;
}
}
}
// traverse the path from leafNodeId to root
// startNode can never be undefined here
let currNode: Message | undefined = startNode;
while (currNode) {
if (currNode.type !== 'root' || (currNode.type === 'root' && includeRoot))
res.push(currNode);
currNode = nodeMap.get(currNode.parent ?? -1);
}
res.sort((a, b) => a.timestamp - b.timestamp);
return res;
},
/**
* create a new conversation with a default root node
*/
async createConversation(name: string): Promise<Conversation> {
const now = Date.now();
const msgId = now;
const conv: Conversation = {
id: `conv-${now}`,
lastModified: now,
currNode: msgId,
name,
};
conv.messages.push(msg);
conv.lastModified = Date.now();
localStorage.setItem(convId, JSON.stringify(conv));
dispatchConversationChange(convId);
await db.conversations.add(conv);
// create a root node
await db.messages.add({
id: msgId,
convId: conv.id,
type: 'root',
timestamp: now,
role: 'system',
content: '',
parent: -1,
children: [],
});
return conv;
},
/**
* Get new conversation id
* if convId does not exist, throw an error
*/
getNewConvId(): string {
return `conv-${Date.now()}`;
async appendMsg(
msg: Exclude<Message, 'parent' | 'children'>,
parentNodeId: Message['id']
): Promise<void> {
if (msg.content === null) return;
const { convId } = msg;
await db.transaction('rw', db.conversations, db.messages, async () => {
const conv = await StorageUtils.getOneConversation(convId);
const parentMsg = await db.messages
.where({ convId, id: parentNodeId })
.first();
// update the currNode of conversation
if (!conv) {
throw new Error(`Conversation ${convId} does not exist`);
}
if (!parentMsg) {
throw new Error(
`Parent message ID ${parentNodeId} does not exist in conversation ${convId}`
);
}
await db.conversations.update(convId, {
lastModified: Date.now(),
currNode: msg.id,
});
// update parent
await db.messages.update(parentNodeId, {
children: [...parentMsg.children, msg.id],
});
// create message
await db.messages.add({
...msg,
parent: parentNodeId,
children: [],
});
});
dispatchConversationChange(convId);
},
/**
* remove conversation by id
*/
remove(convId: string): void {
localStorage.removeItem(convId);
async remove(convId: string): Promise<void> {
await db.transaction('rw', db.conversations, db.messages, async () => {
await db.conversations.delete(convId);
await db.messages.where({ convId }).delete();
});
dispatchConversationChange(convId);
},
/**
* remove all conversations
*/
filterAndKeepMsgs(
convId: string,
predicate: (msg: Message) => boolean
): void {
const conv = StorageUtils.getOneConversation(convId);
if (!conv) return;
conv.messages = conv.messages.filter(predicate);
conv.lastModified = Date.now();
localStorage.setItem(convId, JSON.stringify(conv));
dispatchConversationChange(convId);
},
/**
* remove last message from conversation
*/
popMsg(convId: string): Message | undefined {
const conv = StorageUtils.getOneConversation(convId);
if (!conv) return;
const msg = conv.messages.pop();
conv.lastModified = Date.now();
if (conv.messages.length === 0) {
StorageUtils.remove(convId);
} else {
localStorage.setItem(convId, JSON.stringify(conv));
}
dispatchConversationChange(convId);
return msg;
},
// event listeners
onConversationChanged(callback: CallbackConversationChanged) {
@@ -136,3 +206,79 @@ const StorageUtils = {
};
export default StorageUtils;
// Migration from localStorage to IndexedDB
// these are old types, LS prefix stands for LocalStorage
interface LSConversation {
id: string; // format: `conv-{timestamp}`
lastModified: number; // timestamp from Date.now()
messages: LSMessage[];
}
interface LSMessage {
id: number;
role: 'user' | 'assistant' | 'system';
content: string;
timings?: TimingReport;
}
async function migrationLStoIDB() {
if (localStorage.getItem('migratedToIDB')) return;
const res: LSConversation[] = [];
for (const key in localStorage) {
if (key.startsWith('conv-')) {
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
}
}
if (res.length === 0) return;
await db.transaction('rw', db.conversations, db.messages, async () => {
let migratedCount = 0;
for (const conv of res) {
const { id: convId, lastModified, messages } = conv;
const firstMsg = messages[0];
const lastMsg = messages.at(-1);
if (messages.length < 2 || !firstMsg || !lastMsg) {
console.log(
`Skipping conversation ${convId} with ${messages.length} messages`
);
continue;
}
const name = firstMsg.content ?? '(no messages)';
await db.conversations.add({
id: convId,
lastModified,
currNode: lastMsg.id,
name,
});
const rootId = messages[0].id - 2;
await db.messages.add({
id: rootId,
convId: convId,
type: 'root',
timestamp: rootId,
role: 'system',
content: '',
parent: -1,
children: [firstMsg.id],
});
for (let i = 0; i < messages.length; i++) {
const msg = messages[i];
await db.messages.add({
...msg,
type: 'text',
convId: convId,
timestamp: msg.id,
parent: i === 0 ? rootId : messages[i - 1].id,
children: i === messages.length - 1 ? [] : [messages[i + 1].id],
});
}
migratedCount++;
console.log(
`Migrated conversation ${convId} with ${messages.length} messages`
);
}
console.log(
`Migrated ${migratedCount} conversations from localStorage to IndexedDB`
);
localStorage.setItem('migratedToIDB', '1');
});
}
+53 -1
View File
@@ -5,11 +5,46 @@ export interface TimingReport {
predicted_ms: number;
}
/**
* What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow.
* Inspired by ChatGPT / Claude / Hugging Chat where you edit a message, a new branch of the conversation is created, and the old message is still visible.
*
* We use the same node-based structure like other chat UIs, where each message has a parent and children. A "root" message is the first message in a conversation, which will not be displayed in the UI.
*
* root
* message 1
* message 2
* message 3
* message 4
* message 5
*
* In the above example, assuming that user wants to edit message 2, a new branch will be created:
*
* message 2
* message 3
* message 6
*
* Message 2 and 6 are siblings, and message 6 is the new branch.
*
* We only need to know the last node (aka leaf) to get the current branch. In the above example, message 5 is the leaf of branch containing message 4 and 5.
*
* For the implementation:
* - StorageUtils.getMessages() returns list of all nodes
* - StorageUtils.filterByLeafNodeId() filters the list of nodes from a given leaf node
*/
// Note: the term "message" and "node" are used interchangeably in this context
export interface Message {
id: number;
convId: string;
type: 'text' | 'root';
timestamp: number; // timestamp from Date.now()
role: 'user' | 'assistant' | 'system';
content: string;
timings?: TimingReport;
// node based system for branching
parent: Message['id'];
children: Message['id'][];
}
export type APIMessage = Pick<Message, 'role' | 'content'>;
@@ -17,9 +52,26 @@ export type APIMessage = Pick<Message, 'role' | 'content'>;
export interface Conversation {
id: string; // format: `conv-{timestamp}`
lastModified: number; // timestamp from Date.now()
messages: Message[];
currNode: Message['id']; // the current message node being viewed
name: string;
}
export interface ViewingChat {
conv: Readonly<Conversation>;
messages: Readonly<Message[]>;
}
export type PendingMessage = Omit<Message, 'content'> & {
content: string | null;
};
export enum CanvasType {
PY_INTERPRETER,
}
export interface CanvasPyInterpreter {
type: CanvasType.PY_INTERPRETER;
content: string;
}
export type CanvasData = CanvasPyInterpreter;
+4
View File
@@ -72,5 +72,9 @@ export default defineConfig({
proxy: {
'/v1': 'http://localhost:8080',
},
headers: {
'Cross-Origin-Embedder-Policy': 'require-corp',
'Cross-Origin-Opener-Policy': 'same-origin',
},
},
});
-2
View File
@@ -10,8 +10,6 @@ extern "C" {
#define GGML_VK_NAME "Vulkan"
#define GGML_VK_MAX_DEVICES 16
GGML_BACKEND_API void ggml_vk_instance_init(void);
// backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
+1 -1
View File
@@ -198,7 +198,7 @@
#ifndef __GNUC__
# define GGML_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__)
#elif defined(__MINGW32__) && !defined(__clang__)
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
-2
View File
@@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
GGML_TABLE_END()
//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
@@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
GGML_TABLE_END()
//#endif
GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
+704 -17
View File
@@ -742,7 +742,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__wasm_simd128__)
#elif defined __wasm_simd128__
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
@@ -1030,7 +1030,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
}
#elif defined(__wasm_simd128__)
#elif defined __wasm_simd128__
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
@@ -1644,7 +1644,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
//===================================== Q8_K ==============================================
void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
#ifdef __wasm_simd128__
assert(k % QK_K == 0);
const int64_t nb = k / QK_K;
block_q8_K * restrict yc = y; // Cast to proper type
for (int i = 0; i < nb; i++) {
const float * x_block = x + i * QK_K;
v128_t min_vec = wasm_v128_load(x_block);
v128_t max_vec = min_vec;
for (int j = 4; j < QK_K; j += 4) {
v128_t x_vec = wasm_v128_load(x_block + j);
max_vec = wasm_f32x4_pmax(max_vec, x_vec);
min_vec = wasm_f32x4_pmin(min_vec, x_vec);
}
max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
float max = wasm_f32x4_extract_lane(max_vec, 0);
float min = wasm_f32x4_extract_lane(min_vec, 0);
float amax = -min > max ? min : max;
if (amax == 0.0f) {
yc[i].d = 0.0f;
const v128_t zero = wasm_i8x16_splat(0);
for (int j = 0; j < QK_K; j += 16) {
wasm_v128_store(yc[i].qs + j, zero);
}
continue;
}
const float iscale = -127.0f / amax;
const v128_t scale_vec = wasm_f32x4_splat(iscale);
// Process 16 elements per iteration
for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
// Load and quantize 16 floats
v128_t x0 = wasm_v128_load(x_block + j);
v128_t x1 = wasm_v128_load(x_block + j + 4);
v128_t x2 = wasm_v128_load(x_block + j + 8);
v128_t x3 = wasm_v128_load(x_block + j + 12);
v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
// Convert to i32 with saturation
v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
// Pack into 16 i8 values
v128_t i8 = wasm_i8x16_narrow_i16x8(
wasm_i16x8_narrow_i32x4(i0, i1),
wasm_i16x8_narrow_i32x4(i2, i3)
);
wasm_v128_store(yc[i].qs + j, i8);
// Calculate bsums using SIMD
v128_t sum16 = wasm_i16x8_add(
wasm_i16x8_extend_low_i8x16(i8),
wasm_i16x8_extend_high_i8x16(i8)
);
v128_t sum32 = wasm_i32x4_add(
wasm_i32x4_extend_low_i16x8(sum16),
wasm_i32x4_extend_high_i16x8(sum16)
);
sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
}
yc[i].d = 1.0f / iscale;
}
#else
quantize_row_q8_K_ref(x, y, k);
#endif
}
//===================================== Dot products =================================
@@ -2002,6 +2082,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
const v128_t m4b = wasm_i8x16_splat(0x0F);
const v128_t s8b = wasm_i8x16_splat(0x8);
for (; ib + 1 < nb; ib += 2) {
const block_q4_0 * restrict x0 = &x[ib];
const block_q4_0 * restrict x1 = &x[ib + 1];
const block_q8_0 * restrict y0 = &y[ib];
const block_q8_0 * restrict y1 = &y[ib + 1];
// Load and process x0
v128_t v0_0 = wasm_v128_load(x0->qs);
v128_t v0_0l = wasm_v128_and(v0_0, m4b);
v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
// Load y0 vectors
v128_t y0_l = wasm_v128_load(y0->qs);
v128_t y0_h = wasm_v128_load(y0->qs + 16);
// Extend to i16x8 and compute dot products
v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
v128_t dp0 = wasm_i32x4_add(
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx0l, dy0ll),
wasm_i32x4_dot_i16x8(dx0h, dy0lh)
),
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
)
);
// Load and process x1
v128_t v0_1 = wasm_v128_load(x1->qs);
v128_t v0_1l = wasm_v128_and(v0_1, m4b);
v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
// Load y1 vectors
v128_t y1_l = wasm_v128_load(y1->qs);
v128_t y1_h = wasm_v128_load(y1->qs + 16);
// Extend to i16x8 and compute dot products
v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
v128_t dp1 = wasm_i32x4_add(
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx1l, dy1ll),
wasm_i32x4_dot_i16x8(dx1h, dy1lh)
),
wasm_i32x4_add(
wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
)
);
// Accumulate results with scaling
float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d);
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
}
sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@@ -2688,10 +2856,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__wasm_simd128__)
#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
uint32_t qh;
uint32_t qh_;
uint64_t tmp[4];
// TODO: check if unrolling this is better
@@ -2702,12 +2870,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const v128_t m4b = wasm_i8x16_splat(0x0F);
// extract the 5th bit
memcpy(&qh, x0->qh, sizeof(qh));
memcpy(&qh_, x0->qh, sizeof(qh_));
tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_1[(qh >> 24) ];
tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
tmp[3] = table_b2b_1[(qh_ >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3049,12 +3217,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
#elif defined(__wasm_simd128__)
#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
float summs = 0.0f;
uint32_t qh;
uint32_t qh_;
uint64_t tmp[4];
// TODO: check if unrolling this is better
@@ -3067,12 +3235,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
const v128_t m4b = wasm_i8x16_splat(0x0F);
// extract the 5th bit
memcpy(&qh, x0->qh, sizeof(qh));
memcpy(&qh_, x0->qh, sizeof(qh_));
tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_0[(qh >> 24) ];
tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
tmp[3] = table_b2b_0[(qh_ >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3565,6 +3733,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined __wasm_simd128__
v128_t sumv = wasm_f32x4_splat(0.0f);
for (; ib < nb; ++ib) {
const block_q8_0 * restrict x0 = &x[ib];
const block_q8_0 * restrict y0 = &y[ib];
const v128_t x0_0 = wasm_v128_load(x0->qs);
const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
const v128_t y0_0 = wasm_v128_load(y0->qs);
const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
// Extend 8-bit to 16-bit
const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
// Compute dot products
const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
// Sum all dot products
const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
// Convert to float and accumulate
const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
}
sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@@ -4439,6 +4646,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc);
#elif defined __wasm_simd128__
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * q2 = x[i].qs;
const int8_t * q8 = y[i].qs;
const uint8_t * sc = x[i].scales;
// Vectorized summs calculation
v128_t summs_vec = wasm_i32x4_splat(0);
{
v128_t sc_vec = wasm_v128_load(sc);
v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
summs_vec = wasm_i32x4_add(
wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
wasm_i32x4_dot_i16x8(sc_high, bsums2)),
summs_vec
);
summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
}
int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
// Vectorized isum calculation
int32_t isum = 0;
const uint8_t * sc_ptr = sc;
const int k_iters = QK_K/128;
for (int k = 0; k < k_iters; ++k) {
v128_t isum_vec = wasm_i32x4_splat(0);
int shift = 0;
for (int j = 0; j < 4; ++j) {
const int d0 = (sc_ptr[0] & 0xF);
const int d1 = (sc_ptr[1] & 0xF);
sc_ptr += 2;
// Process first 16 elements
v128_t q2_0 = wasm_v128_load(q2);
v128_t q8_0 = wasm_v128_load(q8);
v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
// Process next 16 elements
v128_t q2_1 = wasm_v128_load(q2 + 16);
v128_t q8_1 = wasm_v128_load(q8 + 16);
v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
// Calculate dot products
v128_t p0 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q8_0),
wasm_i16x8_extend_low_i8x16(q2_bits_0)
);
v128_t p1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q8_0),
wasm_i16x8_extend_high_i8x16(q2_bits_0)
);
v128_t p2 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q8_1),
wasm_i16x8_extend_low_i8x16(q2_bits_1)
);
v128_t p3 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q8_1),
wasm_i16x8_extend_high_i8x16(q2_bits_1)
);
// Accumulate scaled results
v128_t scaled = wasm_i32x4_add(
wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
);
isum_vec = wasm_i32x4_add(isum_vec, scaled);
q8 += 32;
shift += 2;
}
q2 += 32;
// Horizontal sum of isum_vec
isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
isum += wasm_i32x4_extract_lane(isum_vec, 0);
}
const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
sumf += dall * isum - dmin * summs;
}
*s = sumf;
#elif defined __riscv_v_intrinsic
float sumf = 0;
@@ -5121,6 +5428,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc);
#elif defined __wasm_simd128__
int8_t aux8[QK_K];
float sums[8] = {0};
uint32_t auxs[4];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const uint8_t * restrict q3 = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
// Process blocks with SIMD
int8_t * a = aux8;
uint8_t m = 1;
for (int j = 0; j < QK_K; j += 128) {
for (int shift = 0; shift <= 6; shift += 2) {
v128_t v_m = wasm_i8x16_splat(m);
for (int l = 0; l < 32; l += 16) {
v128_t v_q3 = wasm_v128_load(q3 + l);
v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
v128_t v_hm = wasm_v128_load(hm + l);
v128_t v_mask = wasm_v128_and(v_hm, v_m);
v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
wasm_v128_store(a + l, v_low2);
}
a += 32;
m <<= 1;
}
q3 += 32;
}
// Extract scales
memcpy(auxs, x[i].scales, 12);
uint32_t tmp = auxs[2];
auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
const int8_t * scales = (const int8_t *)auxs;
// SIMD dot product with register accumulators
v128_t v_acc0 = wasm_i32x4_splat(0);
v128_t v_acc1 = wasm_i32x4_splat(0);
a = aux8;
for (int j = 0; j < QK_K/16; ++j) {
const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
// Process 16 elements per iteration
for (int k = 0; k < 2; ++k) {
const v128_t v_q8 = wasm_i16x8_load8x8(q8);
const v128_t v_a = wasm_i16x8_load8x8(a);
v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
v_prod = wasm_i16x8_mul(v_prod, v_scale);
v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
q8 += 8;
a += 8;
}
}
// Accumulate results
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const v128_t v_d = wasm_f32x4_splat(d);
v128_t v_sum = wasm_f32x4_add(
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
);
// Accumulate into sums vector
wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
}
// Horizontal sum
v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
sumf = wasm_f32x4_extract_lane(v_sum, 0) +
wasm_f32x4_extract_lane(v_sum, 1) +
wasm_f32x4_extract_lane(v_sum, 2) +
wasm_f32x4_extract_lane(v_sum, 3);
*s = sumf;
#elif defined __riscv_v_intrinsic
uint32_t aux[3];
@@ -5646,7 +6041,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
}
}
*s = sumf;
#elif __ARM_NEON
#elif defined __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0);
@@ -5709,6 +6104,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = sumf;
#elif defined __wasm_simd128__
const uint8_t * scales = (const uint8_t*)&utmp[0];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
// Process scales and mins
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
// Sum mins * q8sums
int32_t sumi = 0;
const int16_t * restrict q8sums = y[i].bsums;
const uint8_t * m = (const uint8_t *)&utmp[2];
for (int j = 0; j < 16; j += 2) {
sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
}
sumf -= dmin * sumi;
int32_t sumi1 = 0;
int32_t sumi2 = 0;
for (int j = 0; j < QK_K/64; ++j) {
// Load 64 4-bit weights (32 bytes)
const v128_t q4x0 = wasm_v128_load(q4);
const v128_t q4x1 = wasm_v128_load(q4 + 16);
q4 += 32;
// Split into low/high nibbles
const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
// Load 64 8-bit values (64 bytes)
const v128_t q8x0 = wasm_v128_load(q8);
const v128_t q8x1 = wasm_v128_load(q8 + 16);
const v128_t q8x2 = wasm_v128_load(q8 + 32);
const v128_t q8x3 = wasm_v128_load(q8 + 48);
q8 += 64;
// Low nibble products
v128_t vacc1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4l0),
wasm_i16x8_extend_low_i8x16(q8x0)
);
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4l0),
wasm_i16x8_extend_high_i8x16(q8x0)
));
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4l1),
wasm_i16x8_extend_low_i8x16(q8x1)
));
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4l1),
wasm_i16x8_extend_high_i8x16(q8x1)
));
// High nibble products
v128_t vacc2 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4h0),
wasm_i16x8_extend_low_i8x16(q8x2)
);
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4h0),
wasm_i16x8_extend_high_i8x16(q8x2)
));
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q4h1),
wasm_i16x8_extend_low_i8x16(q8x3)
));
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q4h1),
wasm_i16x8_extend_high_i8x16(q8x3)
));
// Accumulate scaled results
int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
sumi1 += vacc1_sum * scales[2*j];
int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
sumi2 += vacc2_sum * scales[2*j+1];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
#elif defined __AVX2__
const __m256i m4 = _mm256_set1_epi8(0xF);
@@ -6459,6 +6955,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc) + summs;
#elif defined __wasm_simd128__
//const uint8_t * scales = (const uint8_t*)&utmp[0];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
const uint8_t * restrict q5 = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
// Process scales and mins
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
// Sum mins * q8sums
int32_t sumi_mins = 0;
const int16_t * restrict q8sums = y[i].bsums;
const uint8_t * m = (const uint8_t *)&utmp[2];
for (int j = 0; j < 16; j += 2) {
sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
}
sumf -= dmin * sumi_mins; // Correct subtraction
v128_t qh0 = wasm_v128_load(qh);
v128_t qh1 = wasm_v128_load(qh + 16);
const uint8_t * sc = (const uint8_t *)utmp;
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
const int shift = j * 2;
v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
v128_t q5_0 = wasm_v128_load(q5);
v128_t q5_1 = wasm_v128_load(q5 + 16);
q5 += 32;
v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
v128_t q8_0 = wasm_v128_load(q8);
v128_t q8_1 = wasm_v128_load(q8 + 16);
v128_t q8_2 = wasm_v128_load(q8 + 32);
v128_t q8_3 = wasm_v128_load(q8 + 48);
q8 += 64;
// Process low quants
v128_t pl0 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5l_0),
wasm_i16x8_extend_low_i8x16(q8_0)
);
pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5l_0),
wasm_i16x8_extend_high_i8x16(q8_0)
));
v128_t pl1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5l_1),
wasm_i16x8_extend_low_i8x16(q8_1)
);
pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5l_1),
wasm_i16x8_extend_high_i8x16(q8_1)
));
v128_t sum_low = wasm_i32x4_add(pl0, pl1);
// Process high quants
v128_t ph0 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5h_0),
wasm_i16x8_extend_low_i8x16(q8_2)
);
ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5h_0),
wasm_i16x8_extend_high_i8x16(q8_2)
));
v128_t ph1 = wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_low_i8x16(q5h_1),
wasm_i16x8_extend_low_i8x16(q8_3)
);
ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
wasm_i16x8_extend_high_i8x16(q5h_1),
wasm_i16x8_extend_high_i8x16(q8_3)
));
v128_t sum_high = wasm_i32x4_add(ph0, ph1);
// Accumulate with scale factors
int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
sumi += sl * sc[2*j] + sh * sc[2*j+1];
}
sumf += d * sumi;
}
*s = sumf;
#elif defined __riscv_v_intrinsic
const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7122,6 +7730,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc);
#elif defined __wasm_simd128__
int8_t aux8[QK_K] __attribute__((aligned(16)));
int32_t aux32[8] __attribute__((aligned(16))) = {0};
float sums[8] __attribute__((aligned(16))) = {0};
for (int i = 0; i < nb; ++i) {
// Unpack 6-bit quantized data into aux8 (unchanged)
const uint8_t * restrict q4 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
int8_t * a = aux8;
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
}
a += 128;
q4 += 64;
qh += 32;
}
const int8_t * restrict a_ptr = aux8;
const int8_t * restrict q8 = y[i].qs;
v128_t acc0 = wasm_i32x4_splat(0);
v128_t acc1 = wasm_i32x4_splat(0);
for (int j = 0; j < QK_K/16; ++j) {
const int scale = x[i].scales[j];
const v128_t vscale = wasm_i32x4_splat(scale);
// Load 16 elements from a and q8
const v128_t a_vec = wasm_v128_load(a_ptr);
const v128_t q8_vec = wasm_v128_load(q8);
// Process low 8 elements
v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
// Process high 8 elements
v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
// Scale and accumulate
prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
a_ptr += 16;
q8 += 16;
}
// Store accumulated results
wasm_v128_store(&aux32[0], acc0);
wasm_v128_store(&aux32[4], acc1);
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
for (int l = 0; l < 8; ++l) {
sums[l] += d * aux32[l];
}
}
// Sum final results
float sumf = 0;
for (int l = 0; l < 8; ++l) {
sumf += sums[l];
}
*s = sumf;
#elif defined __riscv_v_intrinsic
float sumf = 0;
+190 -91
View File
@@ -7,10 +7,8 @@
#include "ggml-cpu-impl.h"
#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-quants.h"
#include "ggml-cpu-quants.h"
#include "ggml-threading.h"
#include "amx/amx.h"
#include "ggml.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
@@ -1291,7 +1289,7 @@ struct ggml_threadpool {
atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
atomic_int GGML_CACHE_ALIGN n_barrier;
atomic_int GGML_CACHE_ALIGN n_barrier_passed;
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
// these are atomic as an annotation for thread-sanitizer
atomic_bool stop; // Used for stopping the threadpool altogether
@@ -7490,6 +7488,7 @@ UseGgmlGemm1:;
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t nbw0 = ggml_type_size(vec_dot_type);
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12;
@@ -7497,6 +7496,7 @@ UseGgmlGemm1:;
assert(params->wsize >= ne13*nbw3);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
#if 0
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@@ -7506,6 +7506,20 @@ UseGgmlGemm1:;
}
}
}
#else
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
size_t bs = ggml_blck_size(vec_dot_type);
int64_t ne10_block_start = (ith * ne10/bs) / nth;
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
(ne10_block_end - ne10_block_start) * bs);
}
}
}
#endif
}
if (ith == 0) {
@@ -7593,7 +7607,6 @@ UseGgmlGemm2:;
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
num_rows_per_vec_dot = 1;
}
ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) {
@@ -7606,6 +7619,84 @@ UseGgmlGemm2:;
// ggml_compute_forward_mul_mat_id
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
static void ggml_compute_forward_mul_mat_id_one_chunk(
struct ggml_tensor * dst,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
const struct ggml_tensor * ids,
const int64_t cur_a,
const int64_t ir0_start,
const int64_t ir0_end,
const int64_t ir1_start,
const int64_t ir1_end,
const char * src0_cur,
const struct mmid_row_mapping * matrix_rows,
const size_t row_size,
const bool src1_cont,
const void * wdata) {
GGML_TENSOR_BINARY_OP_LOCALS
const enum ggml_type type = src0->type;
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
float tmp[16];
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
const int64_t _i12 = ir1; // logical row index for this expert
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
const int id = row_mapping.i1; // selected expert index
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11)*row_size
: (i11*nb11 + i12*nb12));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
}
}
}
}
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
void * ptr = *p;
ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
*p = (void *) ((char *) ptr + size);
return ptr;
}
static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id(
const bool src1_cont = ggml_is_contiguous(src1);
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
@@ -7641,21 +7731,27 @@ static void ggml_compute_forward_mul_mat_id(
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
void * wdata_cur = params->wdata;
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
if (src1->type != vec_dot_type) {
incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
}
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
int64_t * matrix_row_counts = // [n_as]
incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t nbw0 = ggml_type_size(vec_dot_type);
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
const size_t nbw2 = nbw1*ne11;
const size_t nbw3 = nbw2*ne12;
@@ -7663,19 +7759,32 @@ static void ggml_compute_forward_mul_mat_id(
assert(params->wsize >= ne13*nbw3);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
#if 0
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
}
}
}
#else
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
size_t bs = ggml_blck_size(vec_dot_type);
int64_t ne10_block_start = (ith * ne10/bs) / nth;
int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
(ne10_block_end - ne10_block_start) * bs);
}
}
}
#endif
}
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
if (ith == 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -7693,9 +7802,14 @@ static void ggml_compute_forward_mul_mat_id(
}
}
// reset current_chunk
for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
*current_chunk_ctr = nth;
}
ggml_barrier(params->threadpool);
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
@@ -7703,84 +7817,64 @@ static void ggml_compute_forward_mul_mat_id(
continue;
}
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const char * src0_cur = (const char *) src0->data + cur_a * nb02;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows
const int64_t nr0 = ne01;
const int64_t nr1 = cne1;
// distribute the thread work across the inner or outer loop based on which one is larger
int chunk_size = 16;
if (nr0 == 1 || nr1 == 1) {
chunk_size = 64;
}
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
#if defined(__aarch64__)
// disable for ARM
const bool disable_chunking = true;
#else
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
#endif // defined(__aarch64__)
const int64_t ith0 = ith % nth0;
const int64_t ith1 = ith / nth0;
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
nchunk0 = nr0 > nr1 ? nth : 1;
nchunk1 = nr0 > nr1 ? 1 : nth;
}
const int64_t ir010 = dr0*ith0;
const int64_t ir011 = MIN(ir010 + dr0, nr0);
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
int current_chunk = ith;
// threads with no work simply yield (not sure if it helps)
//if (ir010 >= ir011 || ir110 >= ir111) {
// sched_yield();
// continue;
//}
atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
// block-tiling attempt
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
// attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16];
const int64_t ir0_start = dr0 * ith0;
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t _i12 = ir1; // logical row index for this expert
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
const int id = row_mapping.i1; // selected expert index
ggml_compute_forward_mul_mat_id_one_chunk(
dst, src0, src1, ids, cur_a,
ir0_start, ir0_end, ir1_start, ir1_end,
src0_cur, matrix_rows, row_size, src1_cont, wdata
);
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11)*row_size
: (i11*nb11 + i12*nb12));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
if (nth >= nchunk0 * nchunk1) {
break;
}
current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
}
}
#undef MMID_MATRIX_ROW
}
// ggml_compute_forward_out_prod
@@ -9074,10 +9168,6 @@ static void ggml_compute_forward_clamp_f32(
const struct ggml_tensor * src0 = dst->src[0];
if (params->ith != 0) {
return;
}
float min;
float max;
memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
@@ -13717,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan(
cur = 0;
const struct ggml_tensor * src0 = node->src[0];
const struct ggml_tensor * src1 = node->src[1];
const struct ggml_tensor * ids = node->src[2];
const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
}
const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
// src1
if (src1->type != vec_dot_type) {
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
}
// matrix_row_counts
cur += n_as * sizeof(int64_t) + sizeof(int64_t);
// matrix_rows
cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
// atomic_current_chunk
cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
} break;
case GGML_OP_OUT_PROD:
{
@@ -13856,9 +13951,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
tp->ec = GGML_STATUS_ABORTED;
}
ggml_barrier(state->threadpool);
if (node_n + 1 < cgraph->n_nodes) {
ggml_barrier(state->threadpool);
}
}
ggml_barrier(state->threadpool);
return 0;
}
+2 -5
View File
@@ -284,14 +284,14 @@ struct ggml_backend_cpu_device_context {
&hKey) == ERROR_SUCCESS) {
DWORD cpu_brand_size = 0;
if (RegQueryValueExA(hKey,
TEXT("ProcessorNameString"),
"ProcessorNameString",
NULL,
NULL,
NULL,
&cpu_brand_size) == ERROR_SUCCESS) {
description.resize(cpu_brand_size);
if (RegQueryValueExA(hKey,
TEXT("ProcessorNameString"),
"ProcessorNameString",
NULL,
NULL,
(LPBYTE)&description[0], // NOLINT
@@ -534,9 +534,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
if (ggml_cpu_has_dotprod()) {
features.push_back({ "DOTPROD", "1" });
}
if (ggml_cpu_has_matmul_int8()) {
features.push_back({ "MATMUL_INT8", "1" });
}
if (ggml_cpu_get_sve_cnt() > 0) {
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
+62 -7
View File
@@ -71,6 +71,47 @@
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220
#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
return false;
}
template<class ... Archs>
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
}
constexpr bool ggml_cuda_has_arch(const int arch) {
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
if (cur == 0) {
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
}
return cur;
}
template<class ... Archs>
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
if (first <= arch && first > cur) {
return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
} else {
return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
}
}
constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
}
#else
static int ggml_cuda_highest_compiled_arch(const int arch) {
return arch;
}
#endif // __CUDA_ARCH_LIST__
// ---------------------------------------------------------------------------------------------------------
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#if defined(_MSC_VER)
@@ -124,11 +165,11 @@ static const char * cu_get_error_str(CUresult err) {
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11100
#endif // CUDART_VERSION >= 11010
#ifdef GGML_CUDA_F16
typedef half dfloat; // dequantize float
@@ -162,18 +203,32 @@ typedef float2 dfloat2;
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
static constexpr bool fast_fp16_available(const int cc) {
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
}
static bool fast_fp16_available(const int cc) {
return fp16_available(cc) && cc != 610;
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}
// Any FP16 tensor cores are available.
static constexpr bool fp16_mma_available(const int cc) {
// Any FP16 tensor core instructions are available for ggml code.
static bool fp16_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static constexpr bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
static bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
+1 -1
View File
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
return dequantize_block_q8_0_f16_cuda;
}
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+13 -11
View File
@@ -178,11 +178,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
int major_version = 0;
size_t version_length = 0;
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
std::string version(version_length, '\0');
std::vector<char> version(version_length+1, '\0');
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
version.resize(::strlen(version.c_str()));
version.resize(::strlen(version.data()));
int parsed_value = 0;
if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) {
if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
major_version = parsed_value;
}
}
@@ -1867,14 +1867,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[id].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
}
// debug helpers
@@ -2840,7 +2840,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
return false;
}
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
if (err != cudaSuccess) {
// clear the error
@@ -2852,8 +2852,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
}
return true;
#else
GGML_UNUSED(buffer);
GGML_UNUSED(size);
return false;
#endif
#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
}
void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
@@ -3205,8 +3207,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+2 -2
View File
@@ -16,7 +16,7 @@
#include "common.cuh"
#if CUDART_VERSION >= 11800
#if CUDART_VERSION >= 11080
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
@@ -50,7 +50,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
return ret_low | ret_high;
}
#endif // CUDART_VERSION >= 11800
#endif // CUDART_VERSION >= 11080
template <typename T>
+6 -5
View File
@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
const int cc = ggml_cuda_info().devices[id].cc;
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
@@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
switch (src0->type) {
@@ -136,7 +137,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return true;
}
if (cc < GGML_CUDA_CC_DP4A) {
if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
return false;
}
@@ -145,8 +146,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
#endif //GGML_CUDA_FORCE_MMQ
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
+8 -6
View File
@@ -86,12 +86,13 @@ struct tile_x_sizes {
int sc;
};
static constexpr int get_mmq_x_max_host(const int cc) {
static int get_mmq_x_max_host(const int cc) {
return new_mma_available(cc) ? 128 :
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
#ifdef GGML_CUDA_FORCE_MMQ
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
128 : 64;
#else
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
MMQ_DP4A_MAX_BATCH_SIZE : 64;
#endif // GGML_CUDA_FORCE_MMQ
}
@@ -119,8 +120,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
#endif // NEW_MMA_AVAILABLE
}
static constexpr int get_mmq_y_host(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
static int get_mmq_y_host(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
}
static constexpr __device__ int get_mmq_y_device() {
@@ -2828,7 +2830,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
int mmq_x_best = 0;
int nparts_best = INT_MAX;
+2 -2
View File
@@ -1,6 +1,6 @@
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#ifdef USE_CUB
#include <cub/cub.cuh>
+6 -7
View File
@@ -103,11 +103,10 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
auto global_mem_size = prop.get_global_mem_size()/1000000;
std::string xmx = gpu_has_xmx(device) ? "yes" : "no";
GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(),
GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
name.c_str(), version.c_str(), prop.get_max_compute_units(),
prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str(), xmx.c_str());
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
}
void ggml_backend_sycl_print_sycl_devices() {
@@ -118,16 +117,16 @@ void ggml_backend_sycl_print_sycl_devices() {
GGML_LOG_INFO(
"| | | | "
" |Max | |Max |Global | | XMX |\n");
" |Max | |Max |Global | |\n");
GGML_LOG_INFO(
"| | | | "
" |compute|Max work|sub |mem | | or |\n");
" |compute|Max work|sub |mem | |\n");
GGML_LOG_INFO(
"|ID| Device Type| "
"Name|Version|units |group |group|size | Driver version| Tensor Cores |\n");
"Name|Version|units |group |group|size | Driver version|\n");
GGML_LOG_INFO(
"|--|-------------------|---------------------------------------|------"
"-|-------|--------|-----|-------|---------------------|--------------|\n");
"-|-------|--------|-----|-------|---------------------|\n");
for (int id = 0; id < device_count; ++id) {
sycl::device device = dpct::dev_mgr::instance().get_device(id);
+323 -316
View File
@@ -167,6 +167,7 @@ struct vk_device_struct {
uint32_t subgroup_size;
uint32_t shader_core_count;
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
bool subgroup_size_control;
@@ -184,12 +185,12 @@ struct vk_device_struct {
size_t idx;
bool mul_mat_l;
bool mul_mat_m;
bool mul_mat_s;
bool mul_mat_id_l;
bool mul_mat_id_m;
bool mul_mat_id_s;
bool mul_mat_l[GGML_TYPE_COUNT];
bool mul_mat_m[GGML_TYPE_COUNT];
bool mul_mat_s[GGML_TYPE_COUNT];
bool mul_mat_id_l[GGML_TYPE_COUNT];
bool mul_mat_id_m[GGML_TYPE_COUNT];
bool mul_mat_id_s[GGML_TYPE_COUNT];
// set to true to indicate that some shaders need to be compiled after the dryrun
bool need_compiles {};
@@ -1294,7 +1295,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk:
static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
vk_buffer buf;
try {
if (device->uma) {
if (device->prefer_host_memory) {
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
} else if (device->uma) {
// Fall back to host memory type
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
} else {
@@ -1378,7 +1381,33 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
return {64, 64};
};
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
uint32_t lut_size = 0;
switch (src0_type) {
case GGML_TYPE_IQ2_XXS:
lut_size = 8*256;
break;
case GGML_TYPE_IQ2_XS:
lut_size = 8*512;
break;
case GGML_TYPE_IQ2_S:
lut_size = 8*1024;
break;
case GGML_TYPE_IQ3_XXS:
lut_size = 4*256;
break;
case GGML_TYPE_IQ3_S:
lut_size = 4*512;
break;
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
lut_size = 4*16;
break;
default:
break;
}
// Needs to be kept up to date on shader changes
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
@@ -1388,7 +1417,13 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
"mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
return supported;
}
static void ggml_vk_load_shaders(vk_device& device) {
@@ -1472,62 +1507,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_align = 64;
s_align = 32;
// Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
// and tile sizes, this should handle 16KB, 32KB, and 48KB+.
// This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
// But the numbers happen to work out for 32KB shared memory size that when using the medium
// size there's enough room for everything, and we assert for this.
uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
l_warptile = m_warptile;
l_wg_denoms = m_wg_denoms;
shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
}
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
// assert mul_mat_mat_id shaders will fit.
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
}
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
l_warptile_mmq = m_warptile_mmq;
l_mmq_wg_denoms = m_mmq_wg_denoms;
} else {
l_warptile_mmq = s_warptile_mmq;
l_mmq_wg_denoms = s_mmq_wg_denoms;
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
ggml_type t = (ggml_type)i;
// Disable medium and large matrix multiplication if not enough shared memory is available
// Check mmq warptiles as the largest configuration
// Throw an error if not enough for any matrix multiplication is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
device->mul_mat_m[i] = false;
device->mul_mat_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
device->mul_mat_l[i] = false;
}
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
}
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
// assert mul_mat_mat_id shaders will fit.
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
}
// Disable medium and large matrix multiplication if not enough shared memory is available
// Check mmq warptiles as the largest configuration
// Throw an error if not enough for any matrix multiplication is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
device->mul_mat_m = false;
device->mul_mat_l = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
device->mul_mat_l = false;
}
// Disable mul_mat_id if not enough shared memory is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
device->mul_mat_id_s = false;
device->mul_mat_id_m = false;
device->mul_mat_id_l = false;
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
device->mul_mat_id_m = false;
device->mul_mat_id_l = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
device->mul_mat_id_l = false;
// Disable mul_mat_id if not enough shared memory is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
device->mul_mat_id_s[i] = false;
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
device->mul_mat_id_l[i] = false;
}
}
}
@@ -1684,119 +1689,116 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat_support) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l) \
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _m) \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _s) \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _l) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
if (device->mul_mat ## ID ## _m) \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
if (device->mul_mat ## ID ## _s) \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->coopmat_acc_f16_support) { \
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
} \
if (device->coopmat_acc_f32_support) { \
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
} \
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
if (device->coopmat_acc_f16_support) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
if (device->coopmat_acc_f16_support) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
if (device->coopmat_acc_f16_support) {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
#undef CREATE_MM
@@ -1804,141 +1806,135 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l) \
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m) \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s) \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _l) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
if (device->mul_mat ## ID ## _m) \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
if (device->mul_mat ## ID ## _s) \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l) \
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m) \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s) \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _l) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
if (device->mul_mat ## ID ## _m) \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
if (device->mul_mat ## ID ## _s) \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM
}
@@ -2206,6 +2202,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->physical_device = physical_devices[dev_num];
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
bool fp16_storage = false;
bool fp16_compute = false;
bool maintenance4_support = false;
@@ -2623,34 +2622,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
// Shaders
// Disable matmul tile sizes early if performance low or not supported
switch (device->vendor_id) {
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
switch (device->vendor_id) {
#ifndef GGML_VULKAN_RUN_TESTS
case VK_VENDOR_ID_AMD:
case VK_VENDOR_ID_INTEL:
device->mul_mat_l = false;
device->mul_mat_m = true;
device->mul_mat_s = true;
device->mul_mat_id_l = false;
device->mul_mat_id_m = true;
device->mul_mat_id_s = true;
break;
case VK_VENDOR_ID_APPLE:
device->mul_mat_l = false;
device->mul_mat_m = true;
device->mul_mat_s = false;
device->mul_mat_id_l = false;
device->mul_mat_id_m = true;
device->mul_mat_id_s = false;
break;
case VK_VENDOR_ID_AMD:
case VK_VENDOR_ID_INTEL:
device->mul_mat_l[i] = false;
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = true;
device->mul_mat_id_l[i] = false;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = true;
break;
case VK_VENDOR_ID_APPLE:
device->mul_mat_l[i] = false;
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = false;
device->mul_mat_id_l[i] = false;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = false;
break;
#endif
default:
device->mul_mat_l = true;
device->mul_mat_m = true;
device->mul_mat_s = true;
device->mul_mat_id_l = true;
device->mul_mat_id_m = true;
device->mul_mat_id_s = true;
break;
default:
device->mul_mat_l[i] = true;
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = true;
device->mul_mat_id_l[i] = true;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = true;
break;
}
}
ggml_vk_load_shaders(device);
@@ -2780,8 +2781,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -2791,14 +2793,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
void ggml_vk_instance_init() {
static void ggml_vk_instance_init() {
if (vk_instance_initialized) {
return;
}
VK_LOG_DEBUG("ggml_vk_instance_init()");
vk_instance_initialized = true;
uint32_t api_version = vk::enumerateInstanceVersion();
if (api_version < VK_API_VERSION_1_2) {
@@ -2849,6 +2849,7 @@ void ggml_vk_instance_init() {
GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
}
vk_instance.instance = vk::createInstance(instance_create_info);
vk_instance_initialized = true;
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
@@ -2873,7 +2874,7 @@ void ggml_vk_instance_init() {
// Make sure at least one device exists
if (devices.empty()) {
std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
GGML_ABORT("fatal error");
return;
}
// Default to using all dedicated GPUs
@@ -3755,31 +3756,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
return split_k;
}
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
}
static void ggml_vk_matmul(
@@ -3806,31 +3807,31 @@ static void ggml_vk_matmul(
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
}
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
}
static void ggml_vk_matmul_id(
@@ -4011,10 +4012,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
@@ -4593,10 +4594,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t y_ne = ne11 * ne10;
const uint64_t d_ne = ne21 * ne20;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -8035,13 +8036,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
ggml_type src0_type = op->src[0]->type;
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
return false;
}
switch (op->src[0]->type) {
switch (src0_type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
@@ -8347,8 +8349,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
/* .iface = */ ggml_backend_vk_reg_i,
/* .context = */ nullptr,
};
return &reg;
try {
ggml_vk_instance_init();
return &reg;
} catch (const vk::SystemError& e) {
VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
return nullptr;
}
}
// Extension availability
+1 -1
View File
@@ -1379,7 +1379,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
(t0->nb[3] == t1->nb[3]);
}
// check if t1 can be represented as a repeatition of t0
// check if t1 can be represented as a repetition of t0
bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+3 -2
View File
@@ -1114,11 +1114,12 @@ extern "C" {
};
struct llama_sampler {
struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
const struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
};
// mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
+1 -1
View File
@@ -1 +1 @@
08b538031f7f944e84f472483ef5d26bf5190ead
98a61a0d0b43cba06c3ac1c603813639552a0701
+1 -1
View File
@@ -116,7 +116,7 @@ struct llama_grammar {
llama_partial_utf8 partial_utf8;
// lazy grammars wait for trigger words or tokens before constraining the sampling.
// we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
// (useful e.g. for tool_choice=required)
bool lazy = false;
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
+6 -6
View File
@@ -6,13 +6,13 @@
#include <vector>
#ifdef __GNUC__
#ifdef __MINGW32__
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# if defined(__MINGW32__) && !defined(__clang__)
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
# else
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
# endif
#else
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define LLAMA_ATTRIBUTE_FORMAT(...)
# define LLAMA_ATTRIBUTE_FORMAT(...)
#endif
//
+1
View File
@@ -1,5 +1,6 @@
#pragma once
#include <cstdint>
#include <memory>
#include <vector>
+1 -1
View File
@@ -1275,7 +1275,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const bool use_mmap_buffer = true;
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, use_mmap_buffer ? "true" : "false");
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");
// build a list of buffer types for the CPU and GPU devices
pimpl->cpu_buft_list = make_cpu_buft_list(devices);
+64 -57
View File
@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
// llama_sampler API
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
return new llama_sampler {
/* .iface = */ iface,
/* .ctx = */ ctx,
};
}
const char * llama_sampler_name(const struct llama_sampler * smpl) {
if (!smpl->iface) {
return "(null)";
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
}
if (smpl->ctx == nullptr) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ smpl->iface,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}
GGML_ABORT("the sampler does not support cloning");
@@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
};
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .samplers = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
},
};
}
);
}
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
};
struct llama_sampler * llama_sampler_init_greedy() {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_greedy_i,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}
// dist
@@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}
// softmax
@@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
};
struct llama_sampler * llama_sampler_init_softmax() {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_softmax_i,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}
// top-k
@@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
};
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k {
/* .k = */ k,
},
};
}
);
}
// top-p
@@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
};
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}
// min-p
@@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
};
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}
// typical
@@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
};
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}
// temp
@@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
};
struct llama_sampler * llama_sampler_init_temp(float temp) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp {
/*.temp = */ temp,
},
};
}
);
}
// temp-ext
@@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
};
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp,
/* .delta = */ delta,
/* .exponent = */ exponent,
},
};
}
);
}
// xtc
@@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
@@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}
// mirostat
@@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab,
@@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}
// mirostat v2
@@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed,
@@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
/* .eta = */ eta,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}
// grammar
@@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
};
}
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_grammar_i,
/* .ctx = */ ctx,
};
/* .ctx = */ ctx
);
}
struct llama_sampler * llama_sampler_init_grammar(
@@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n,
@@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalty_present = */ penalty_present,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {},
},
};
}
);
}
// DRY
@@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
}
}
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ context_size,
@@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {},
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
},
};
}
);
}
// wrapper for test-sampling.cpp
@@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
},
};
}
);
}
// infill
@@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
};
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ vocab,
/* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512),
},
};
}
);
}
// utils
+7 -5
View File
@@ -8801,12 +8801,14 @@ static int llama_decode_impl(
//llama_synchronize(&lctx);
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + llama_kv_cache_get_padding(cparams))/float(kv_self.n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
//LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
llama_kv_cache_defrag(kv_self);
}
@@ -9428,8 +9430,6 @@ static struct llama_model * llama_model_load_from_file_impl(
struct llama_model_params params) {
ggml_time_init();
llama_model * model = new llama_model(params);
unsigned cur_percentage = 0;
if (params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage;
@@ -9447,6 +9447,8 @@ static struct llama_model * llama_model_load_from_file_impl(
};
}
llama_model * model = new llama_model(params);
// create list of devices to use with this model
if (params.devices) {
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
+8 -1
View File
@@ -618,7 +618,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
result.reserve(utf8.size());
size_t offset = 0;
while (offset < utf8.size()) {
result.push_back(unicode_cpt_from_utf8(utf8, offset));
try {
result.push_back(unicode_cpt_from_utf8(utf8, offset));
}
catch (const std::invalid_argument & /*ex*/) {
// Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
++offset;
result.emplace_back(0xFFFD); // replacement character
}
}
return result;
}
+4 -4
View File
@@ -697,8 +697,8 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
#ifdef _WIN32
if (!file) {
printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
printf("%s: skipping tests");
printf("failed to create tmpfile(), needs elevated privileges on Windows");
printf("skipping tests");
continue;
}
#else
@@ -1086,8 +1086,8 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned
#ifdef _WIN32
if (!file) {
printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
printf("%s: skipping tests");
printf("failed to create tmpfile(), needs elevated privileges on Windows");
printf("skipping tests");
return std::make_pair(0, 0);
}
#else