Compare commits

...

16 Commits

Author SHA1 Message Date
Giuseppe Scrivano 1568d13c2c vulkan: implement ABS and NEG (#17245)
* docs: update Vulkan ops

* vulkan: add NEG op

* vulkan: add ABS op

---------

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
2025-11-15 12:00:29 +01:00
Jeff Bolz 439342ea0b vulkan: Use ggml_vk_tensor_subbuffer in mul_mat_vec(id) paths (#17244)
* vulkan: Use ggml_vk_tensor_subbuffer in mul_mat_vec(id) paths

* set allow_misalign
2025-11-15 11:56:15 +01:00
Jeff Bolz 234ae7d7bd vulkan: skip all-negative-inf blocks in FA (#17186) 2025-11-15 10:37:25 +01:00
Jeff Bolz 38eaf32af1 vulkan: change graph_compute to be async and enable get_tensor_async (#17158)
* vulkan: change graph_compute to be async and enable get_tensor_async

This allows some additional CPU/GPU overlap for large pp workloads. Also seems
to help a bit for token gen, maybe getting rid of a small bubble between
graph_compute and get_tensor.

Async set and copy functions seem to be very rarely used, so I didn't enable
them because I didn't have a good way to test them.

The async commands need to be ordered against each other, so put them all on
the compute queue. The non-async commands still use the transfer queue.

The fence for graph_compute/get_tensor_async is submitted and waited on in
ggml_vk_synchronize.

* fix thread safety errors

* teardown context cleanly

* Handle async read to non-pinned dst
2025-11-15 09:06:41 +01:00
Xuan-Son Nguyen 9b17d74ab7 mtmd: add mtmd_log_set (#17268) 2025-11-14 15:56:19 +01:00
Bartowski e1fcf8b09b model : add AfmoeForCausalLM support (#16477)
* Add AFMOE model support

* Update to vocab

* Add model sizing

* Undo Rope change for ARCEE model

* Address review comments

* Update modeling code is_sliding -> use_rope, replace hard-coded logic

* Fix AFMOE tokenizer

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update AFMoE tokenizer class identification to be more unique

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-11-14 13:54:10 +01:00
Marek Hradil jr. 6cd0cf72ce fix : Dangling pointer for non-empty trigger words in lazy grammar construction (#17048)
* fix : Dangling pointer for non-empty trigger words in llama_sampler_init_grammar_impl (#17047)

* Replace 'static' workaround, with keeping variable in scope for longer

* Create std::array directly and pass into llama_grammar_init_impl

* Add back the trigger pattern

* Missed array include
2025-11-14 14:35:26 +02:00
Georgi Gerganov d396b43748 server : fix "can batch with" bug (#17263) 2025-11-14 14:03:45 +02:00
Georgi Gerganov 45c6ef7307 metal : support argsort for ne00 > 1024 (#17247)
* metal : refactor argsort

* cont : sort chunks

* cont : merge sorted buckets

* cont : cleanup
2025-11-14 09:36:06 +02:00
Georgi Gerganov 2606b0adab metal : make the FA extra sizes consistent (#17143) 2025-11-14 09:13:34 +02:00
ixgbe 307772fcda readme : add RVV,ZVFH,ZFH,ZICBOP support for RISC-V (#17259)
Signed-off-by: Wang Yang <yangwang@iscas.ac.cn>
2025-11-14 09:12:56 +02:00
Aleksander Grygier f1bad23f88 Better UX for handling multiple attachments in WebUI (#17246) 2025-11-14 01:19:08 +01:00
Alberto Cabrera Pérez becc4816dd ggml-cpu: handle 3d tensors in repack mat_mul (#17241)
* ggml-cpu: handle 3d tensors in repack mul_mat

* Removed unnecessary branch, removed need for <algorithm>

* Fixed dst_ptr pointer in chunk + clang_format

* GGML_ASSERT to check wdata within bounds

* Accidental ggml.h inclusion

* Improved GGML_ASSERT on wdata boundaries

* Address performance regression in Qwen and llama.cpp due to chunking
2025-11-13 12:53:00 -08:00
Xuan-Son Nguyen c4abcb2457 server: fixing naming conflict res_error (#17243) 2025-11-13 20:53:47 +01:00
Piotr Wilkin (ilintar) 389ac78b26 ggml : add ops SOFTPLUS, EXPM1, TRI, SOLVE_TRI, CUMSUM (#17063)
* Add ops needed for new hybrid models: SOFTPLUS, EXPM1, TRI, SOLVE_TRI, CUMSUM

* Update ggml/include/ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Code review

* Whitespace

* Update tests/test-backend-ops.cpp

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* This is actually sigmoid, duh.

* Add CONST, remove TRI_KEEP, other changes from review

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml.c

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml.c

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-cuda/unary.cu

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* Remove extra script

* Update ggml/src/ggml.c

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* Update tests/test-backend-ops.cpp

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* moving changes from laptop [no ci]

* pre-rebase

* Update tests/test-backend-ops.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update tests/test-backend-ops.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Refactor tests

* ggml : cleanup

* cont : fix ggml_fill srcs

* tests : add note

* ggml : add ggml_fill_inplace

* ggml : add asserts

* ggml : fix ggml_fill constant cast

* cont : ggml_tri minor

* Use TENSOR_LOCALS

* Fix regression from #14596, regenerate

* Don't make commits at night...

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-11-13 20:54:47 +02:00
Ruben Ortlam a19bd6f7ce vulkan: remove shell call from vulkan-shaders-gen tool, revert file check (#17219)
* vulkan: remove shell call from vulkan-shaders-gen tool

* use string vector for command execution

* Fix condition

* use string, remove const_cast

* Fix dependency file quotation on Windows

---------

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
2025-11-13 14:51:21 +01:00
72 changed files with 49641 additions and 17235 deletions
+1 -1
@@ -9,7 +9,7 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
- **Size**: ~200k+ lines of code across 1000+ files
- **Architecture**: Modular design with main library (`libllama`) and 40+ executable tools/examples
- **Core dependency**: ggml tensor library (vendored in `ggml/` directory)
- **Backends supported**: CPU (AVX/NEON optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
- **License**: MIT
## Build Instructions
+1
@@ -61,6 +61,7 @@ range of hardware - locally and in the cloud.
- Plain C/C++ implementation without any dependencies
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
- AVX, AVX2, AVX512 and AMX support for x86 architectures
- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
- Vulkan and SYCL backend support
+1 -5
@@ -355,11 +355,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}
void common_init() {
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}, NULL);
llama_log_set(common_log_default_callback, NULL);
#ifdef NDEBUG
const char * build_type = "";
+6
@@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
log->set_timestamps(timestamps);
}
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}
+2
@@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
// the common_log uses an internal worker thread to print/write log messages
// when the worker thread is paused, incoming log messages are discarded
struct common_log;
+78
@@ -1124,6 +1124,9 @@ class TextModel(ModelBase):
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
res = "afmoe"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
res = "bailingmoe2"
@@ -2533,6 +2536,81 @@ class ArceeModel(LlamaModel):
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
@ModelBase.register("AfmoeForCausalLM")
class AfmoeModel(LlamaModel):
    model_arch = gguf.MODEL_ARCH.AFMOE

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        # MoE parameters
        if (n_experts := self.hparams.get("num_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
        if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
            self.gguf_writer.add_expert_shared_count(n_shared_experts)
        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
        if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
            self.gguf_writer.add_leading_dense_block_count(n_dense_layers)

        # Expert Gating Function
        score_func = self.hparams.get("score_func")
        if score_func == "sigmoid":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
        elif score_func == "softmax":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
        elif score_func is not None:
            raise ValueError(f"Unsupported score_function value: {score_func}")

        # Route normalization and scaling
        if (route_norm := self.hparams.get("route_norm")) is not None:
            self.gguf_writer.add_expert_weights_norm(route_norm)
        if (route_scale := self.hparams.get("route_scale")) is not None:
            self.gguf_writer.add_expert_weights_scale(route_scale)

        # Sliding window attention
        if (sliding_window := self.hparams.get("sliding_window")) is not None:
            self.gguf_writer.add_sliding_window(sliding_window)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Handle expert weights - they're already merged in the HF format
        # process the experts separately
        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                tensors: list[tuple[str, Tensor]] = []

                # merge the experts into a single 3d tensor
                for w_name in ["gate_proj", "up_proj", "down_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename_to_retrieve])
                        del self._experts[bid][ename_to_retrieve]

                    data_torch = torch.stack(datas, dim=0)
                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
                    new_name = self.map_tensor_name(merged_name)
                    tensors.append((new_name, data_torch))

                return tensors
            else:
                return []

        if name.endswith(".expert_bias"):
            name = name.replace(".expert_bias", ".expert_bias.bias")

        return [(self.map_tensor_name(name), data_torch)]

@ModelBase.register(
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
+1
@@ -139,6 +139,7 @@ models = [
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
+31 -25
@@ -14,33 +14,36 @@ Legend:
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|-----------|------|------|------|------|------|------|------|------|------|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | | ❌ | ✅ | ❌ | ✅ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| CONV_3D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -54,25 +57,25 @@ Legend:
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| IM2COL_3D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | | ✅ | ❌ | ✅ | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | | ✅ | ❌ | ✅ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
@@ -80,15 +83,15 @@ Legend:
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROLL | ❌ | ❌ | ✅ | | ❌ | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | ❌ |
| ROLL | ❌ | ❌ | ✅ | | ❌ | ❌ | ✅ | ✅ | ❌ |
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SET | ❌ | ❌ | ✅ | | ✅ | ❌ | 🟡 | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | | ✅ | ❌ | 🟡 | ❌ | ❌ |
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
@@ -96,21 +99,24 @@ Legend:
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | | ❌ | ❌ | 🟡 | | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+16067 -5133
File diff suppressed because it is too large
+16224 -6894
File diff suppressed because it is too large
+14534 -4358
File diff suppressed because it is too large
+71
@@ -475,6 +475,7 @@ extern "C" {
GGML_OP_COS,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_CUMSUM,
GGML_OP_MEAN,
GGML_OP_ARGMAX,
GGML_OP_COUNT_EQUAL,
@@ -530,6 +531,8 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
GGML_OP_TRI,
GGML_OP_FILL,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
@@ -542,6 +545,7 @@ extern "C" {
GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_UNARY,
@@ -576,6 +580,8 @@ extern "C" {
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_EXPM1,
GGML_UNARY_OP_SOFTPLUS,
GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_XIELU,
GGML_UNARY_OP_FLOOR,
@@ -620,6 +626,13 @@ extern "C" {
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
};
enum ggml_tri_type {
GGML_TRI_TYPE_UPPER_DIAG = 0,
GGML_TRI_TYPE_UPPER = 1,
GGML_TRI_TYPE_LOWER_DIAG = 2,
GGML_TRI_TYPE_LOWER = 3
};
struct ggml_init_params {
// memory pool
size_t mem_size; // bytes
@@ -957,6 +970,22 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_expm1(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_expm1_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_softplus_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sin(
struct ggml_context * ctx,
struct ggml_tensor * a);
@@ -983,6 +1012,10 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a);
// mean along rows
GGML_API struct ggml_tensor * ggml_mean(
struct ggml_context * ctx,
@@ -2187,6 +2220,23 @@ extern "C" {
int shift2,
int shift3);
// Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
// zeroes everywhere outside the masked area
GGML_API struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type type);
// Fill tensor a with constant c
GGML_API struct ggml_tensor * ggml_fill(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
GGML_API struct ggml_tensor * ggml_fill_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c);
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
@@ -2356,6 +2406,27 @@ extern "C" {
struct ggml_tensor * b,
struct ggml_tensor * state);
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
* without zeroes on the diagonal (i.e. invertible).
* B can have any number of columns, but must have the same number of rows as A
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
* Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
* where n > 100 sparingly, pre-chunk if necessary.
*
* If left = false, solves xA=B instead
* If lower = false, assumes upper triangular instead
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
*
* TODO: currently only lower, right, non-unitriangular variant is implemented
*/
GGML_API struct ggml_tensor * ggml_solve_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
bool left,
bool lower,
bool uni);
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
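
Note on the new `ggml_tri` and `ggml_cumsum` declarations above: the following is a minimal standalone sketch of what the four `ggml_tri_type` values appear to select, written against plain row-major vectors rather than ggml tensors. The names `TriType`, `tri_keep`, `tri_mask` and `cumsum_row` are hypothetical and not part of ggml. The keep/zero predicates mirror the CPU kernel in this changeset (`i < r`, `i <= r`, `i > r`, `i >= r` with `i` the index along dimension 0 and `r` the row). That `cumsum` is an inclusive prefix sum along each row is an assumption, since `ggml_vec_cumsum_f32` is not shown in this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ggml_tri_type; numeric values copied from the enum in the diff.
enum class TriType { UpperDiag = 0, Upper = 1, LowerDiag = 2, Lower = 3 };

// True when the element at (row, col) is kept; everything else is zeroed.
bool tri_keep(TriType t, int col, int row) {
    switch (t) {
        case TriType::Lower:     return col <  row; // strictly below the diagonal
        case TriType::LowerDiag: return col <= row; // below, including the diagonal
        case TriType::Upper:     return col >  row; // strictly above the diagonal
        case TriType::UpperDiag: return col >= row; // above, including the diagonal
    }
    return false;
}

// Masks a row-major n x n matrix, writing zeros outside the selected triangle.
std::vector<float> tri_mask(const std::vector<float> & a, int n, TriType t) {
    assert(a.size() == static_cast<std::size_t>(n) * n);
    std::vector<float> out(a.size(), 0.0f);
    for (int r = 0; r < n; ++r) {
        for (int c = 0; c < n; ++c) {
            if (tri_keep(t, c, r)) {
                out[r * n + c] = a[r * n + c];
            }
        }
    }
    return out;
}

// Assumed semantics of cumsum on one row: inclusive running sum.
std::vector<float> cumsum_row(const std::vector<float> & x) {
    std::vector<float> y(x.size());
    float acc = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i) {
        acc += x[i];
        y[i] = acc;
    }
    return y;
}
```

With a 3 x 3 input holding 1..9 in row-major order, `TriType::Lower` keeps only 4, 7 and 8, while `TriType::UpperDiag` keeps 1, 2, 3, 5, 6 and 9.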
+22
@@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_sum_rows(params, tensor);
} break;
case GGML_OP_CUMSUM:
{
ggml_compute_forward_cumsum(params, tensor);
} break;
case GGML_OP_MEAN:
{
ggml_compute_forward_mean(params, tensor);
@@ -1927,6 +1931,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_leaky_relu(params, tensor);
} break;
case GGML_OP_TRI:
{
ggml_compute_forward_tri(params, tensor);
} break;
case GGML_OP_FILL:
{
ggml_compute_forward_fill(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_compute_forward_flash_attn_ext(params, tensor);
@@ -1982,6 +1994,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rwkv_wkv7(params, tensor);
} break;
case GGML_OP_SOLVE_TRI:
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@@ -2140,6 +2156,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
case GGML_OP_FILL:
{
n_tasks = n_threads;
} break;
@@ -2157,6 +2176,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
n_tasks = 1;
} break;
case GGML_OP_COUNT_EQUAL:
case GGML_OP_SOLVE_TRI:
{
n_tasks = n_threads;
} break;
@@ -2179,6 +2199,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
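
Note on `n_tasks = n_threads` for the new ops: the kernels added in this changeset split a flat list of rows across threads and then decode each flat index back into tensor coordinates. The sketch below reproduces that arithmetic in isolation, using the chunk formula visible in the `solve_tri` kernel (`dr = (nr + nth - 1)/nth`). The helper names `thread_range` and `unflatten_row` are hypothetical, and whether ggml's own `get_thread_range` uses exactly this split is an assumption, since its body is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// Half-open range [first, second) of work units for thread `ith` out of `nth`.
// When there are more threads than chunks, first may exceed second;
// a loop of the form `for (ir = first; ir < second; ++ir)` then does nothing.
std::pair<int64_t, int64_t> thread_range(int64_t nr, int ith, int nth) {
    assert(nth > 0 && ith >= 0 && ith < nth);
    const int64_t dr  = (nr + nth - 1) / nth;     // rows per thread, rounded up
    const int64_t ir0 = dr * ith;                 // first row for this thread
    const int64_t ir1 = std::min(ir0 + dr, nr);   // one past the last row
    return {ir0, ir1};
}

struct RowIndex {
    int64_t i1; // index along dimension 1
    int64_t i2; // index along dimension 2
    int64_t i3; // index along dimension 3
};

// Decodes a flat row index into (i1, i2, i3) for a tensor with ne1 x ne2 x ne3 rows.
RowIndex unflatten_row(int64_t ir, int64_t ne1, int64_t ne2) {
    const int64_t i3 = ir / (ne2 * ne1);
    const int64_t i2 = (ir - i3 * ne2 * ne1) / ne1;
    const int64_t i1 = ir - i3 * ne2 * ne1 - i2 * ne1;
    return {i1, i2, i3};
}
```

For 10 rows on 4 threads this gives chunks of 3, 3, 3 and 1 rows. Rounding the chunk size up keeps the split simple at the cost of a slightly lighter last thread.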
+208 -2
@@ -9,6 +9,7 @@
#include <cfloat>
#include <algorithm>
#include <cmath>
#include <functional>
// ggml_compute_forward_dup
@@ -1395,6 +1396,56 @@ void ggml_compute_forward_sum(
}
}
// ggml_compute_forward_cumsum
static void ggml_compute_forward_cumsum_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(dst->nb[0] == sizeof(float));
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
const auto [ir0, ir1] = get_thread_range(params, src0);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
}
}
void ggml_compute_forward_cumsum(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cumsum_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_sum_rows
static void ggml_compute_forward_sum_rows_f32(
@@ -2141,6 +2192,83 @@ static void ggml_compute_forward_gelu(
}
}
// ggml_compute_fill
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
const float c = ggml_get_op_params_f32(dst, 0);
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
const auto [ir0, ir1] = get_thread_range(params, dst);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne2*ne1);
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
ggml_vec_set_f32(ne0, dst_ptr, c);
}
}
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
ggml_compute_forward_fill_f32(params, dst);
}
// ggml_compute_tri
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_TENSOR_UNARY_OP_LOCALS
const auto [ir0, ir1] = get_thread_range(params, src0);
bool (*bipred)(int, int);
switch (ttype) {
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
default: GGML_ABORT("invalid tri type");
}
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
for (int i0 = 0; i0 < ne0; ++i0) {
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
}
}
}
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_tri_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_gelu_erf
static void ggml_compute_forward_gelu_erf_f32(
@@ -8536,7 +8664,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = ggml_softplus(dt[h]);
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
const float dA = expf(dt_soft_plus * A[h]);
const int g = h / (nh / ng); // repeat_interleave
@@ -8633,7 +8761,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = ggml_softplus(dt[h]);
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
const int g = h / (nh / ng); // repeat_interleave
// dim
@@ -8916,6 +9044,14 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_xielu(params, dst);
} break;
case GGML_UNARY_OP_EXPM1:
{
ggml_compute_forward_expm1(params, dst);
} break;
case GGML_UNARY_OP_SOFTPLUS:
{
ggml_compute_forward_softplus(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
@@ -9512,6 +9648,76 @@ void ggml_compute_forward_gla(
}
}
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
GGML_TENSOR_BINARY_OP_LOCALS;
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ne00 == ne01); // A must be square
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
const int ith = params->ith;
const int nth = params->nth;
const int64_t k = ne10; // number of RHS columns
const int64_t n = ne11; // A is n×n
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
// chunks per thread
const int64_t dr = (nr + nth - 1)/nth;
// chunk range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
const float * A = (const float *) src0->data; // [n, n, B1, B2]
const float * B = (const float *) src1->data; // [n, k, B1, B2]
float * X = ( float *) dst->data; // [n, k, B1, B2]
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*k);
const int64_t i02 = (ir - i03*ne02*k)/k;
const int64_t i01 = (ir - i03*ne02*k - i02*k);
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
for (int64_t i00 = 0; i00 < n; ++i00) {
float sum = 0.0f;
for (int64_t t = 0; t < i00; ++t) {
sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
}
const float diag = A_batch[i00 * n + i00];
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
}
}
}
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_compute_forward_solve_tri_f32(params, dst);
} else {
GGML_ABORT("fatal error");
}
}
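The inner loops of `ggml_compute_forward_solve_tri_f32` are a forward substitution: each unknown is computed from the already-solved unknowns above it and divided by the diagonal entry. A minimal standalone sketch of that recurrence for a single right-hand side follows. `solve_lower_tri` is a hypothetical helper for illustration, not a ggml function, and it assumes a row-major n×n lower-triangular matrix.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of ggml): solve L * x = b for one right-hand side.
// L is an n x n lower-triangular matrix stored row-major.
static std::vector<float> solve_lower_tri(const std::vector<float> & L,
                                          const std::vector<float> & b,
                                          std::size_t n) {
    assert(L.size() == n * n && b.size() == n);
    std::vector<float> x(n, 0.0f);
    for (std::size_t i = 0; i < n; ++i) {
        // accumulate the contribution of the unknowns that are already solved
        float sum = 0.0f;
        for (std::size_t t = 0; t < i; ++t) {
            sum += L[i * n + t] * x[t];
        }
        const float diag = L[i * n + i];
        assert(diag != 0.0f && "zero diagonal in triangular matrix");
        x[i] = (b[i] - sum) / diag;
    }
    return x;
}
```

For example, with L = [[2, 0], [1, 4]] and b = [4, 10], the first row gives x0 = 4 / 2 and the second row gives x1 = (10 - 1 * x0) / 4.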
// ggml_compute_forward_rwkv_wkv7
static void ggml_compute_forward_rwkv_wkv7_f32(
+4
@@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -81,6 +82,8 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
@@ -96,6 +99,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+95 -42
@@ -1600,29 +1600,52 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
return false;
}
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
void forward_mul_mat_one_chunk(ggml_compute_params * params,
ggml_tensor * op,
int64_t src0_start,
int64_t src0_end,
int64_t src1_start,
int64_t src1_end) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
ggml_tensor * dst = op;
GGML_TENSOR_BINARY_OP_LOCALS
const void * src1_wdata = params->wdata;
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
GGML_ASSERT(ne03 == 1 && ne13 == 1);
GGML_ASSERT(ne12 % ne02 == 0);
const int64_t r2 = ne12 / ne02;
const int64_t i12 = src1_start / ne1;
const int64_t i11 = src1_start - i12 * ne1;
// Determine batch index
const int64_t i02 = i12 / r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
const int64_t nrows = src1_end - src1_start;
const int64_t ncols = src0_end - src0_start;
GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (ne11 > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
if (nrows > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
src0_ptr + src0_start * nb01, src1_ptr,
nrows - (nrows % 4), ncols);
}
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
ne01, src0_ptr + src0_start * nb01,
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
}
}
@@ -1647,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// TODO: General batched mul mat for 4D tensors
// Currently only supports 3D tensors
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(ne3 == 1);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
@@ -1654,47 +1683,64 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
char * wdata = static_cast<char *>(params->wdata);
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
const size_t nbw2 = nbw1 * ne11;
assert(params->wsize >= nbw1 * ne11);
assert(params->wsize >= nbw2 * ne12);
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
int64_t i11_processed = 0;
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
}
// INFO: Quantization is done in planes to avoid extra complexity in chunking.
// Flattening dimensions that are not a multiple of INTER_SIZE would require extra handling,
// depending on how the planes are broadcast.
for (int64_t i12 = 0; i12 < ne12; i12++) {
char * data_ptr = (char *) src1->data + i12 * nb12;
char * wdata_ptr = wdata + i12 * nbw2;
i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
}
const int64_t i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
}
}
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// 4x chunks per thread
int64_t nr = ggml_nrows(op->src[0]);
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
const int64_t nr0 = ggml_nrows(op->src[0]);
int nth_scaled = nth * 4;
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
// src1 is chunked only by full planes.
// When we flatten, we need to handle dimensions that are not a multiple of the q8 INTER_SIZE
// and route them through GEMV.
// nchunk1 = ne12 also leaves the chunking unchanged for models with no 3d tensors,
// so their performance is not affected.
int64_t nchunk1 = ne12;
// Ensure minimum chunk size to avoid alignment issues with high thread counts
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
const int64_t min_chunk_size = NB_COLS;
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
}
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
if (nth == 1 || nchunk0 < nth || disable_chunking) {
nchunk0 = nth;
}
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
// This prevents creating too many tiny chunks that could overlap after alignment
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
if (nchunk > max_nchunk) {
nchunk = max_nchunk;
}
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
nchunk0 = MIN(nchunk0, max_nchunk);
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
@@ -1706,23 +1752,30 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
while (current_chunk < nchunk) {
int64_t src0_start = (current_chunk * ne01) / nchunk;
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
int64_t src0_start = dr0 * ith0;
int64_t src0_end = MIN(src0_start + dr0, nr0);
// full-plane range for src1
int64_t src1_start = ith1 * ne11;
int64_t src1_end = (ith1 + 1) * ne11;
// Align boundaries to NB_COLS - round up to ensure all data is included
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
if (src0_end > ne01) {
src0_end = ne01;
}
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
src0_end = MIN(src0_end, ne01);
// Make sure current plane is the last one before exiting
if (src0_start >= src0_end) {
break;
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
continue;
}
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
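The row range handed to each chunk in the loop above can be reproduced in isolation: a ceiling division gives the nominal chunk size, and both boundaries are then rounded up to a multiple of `NB_COLS` and clamped to the row count. The sketch below mirrors that arithmetic. `chunk_rows` and `round_up` are hypothetical names used only for illustration, and the row count is assumed to equal `ne01` (src0 is asserted to be 2D in the surrounding code).

```cpp
#include <algorithm>
#include <cstdint>
#include <utility>

// Round v up to the next multiple of m (v is returned unchanged if already aligned).
static int64_t round_up(int64_t v, int64_t m) {
    return (v % m) ? v + m - (v % m) : v;
}

// Hypothetical helper: [start, end) row range of chunk `ichunk` out of `nchunk0`,
// with both boundaries aligned to nb_cols. An empty range has start >= end.
static std::pair<int64_t, int64_t> chunk_rows(int64_t nr0, int64_t nchunk0,
                                              int64_t nb_cols, int64_t ichunk) {
    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; // rows per chunk, rounded up
    int64_t start = dr0 * ichunk;
    int64_t end   = std::min(start + dr0, nr0);
    start = round_up(start, nb_cols);
    end   = std::min(round_up(end, nb_cols), nr0);
    return {start, end};
}
```

With 100 rows, 4 chunks, and an alignment of 8, the ranges come out as [0, 32), [32, 56), [56, 80), and [80, 100): contiguous, non-overlapping, and aligned everywhere except the final clamp.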
+16
@@ -73,6 +73,14 @@ static inline float op_log(float x) {
return logf(x);
}
static inline float op_expm1(float x) {
return expf(x) - 1.0f;
}
static inline float op_softplus(float x) {
return (x > 20.0f) ? x : logf(1.0f + expf(x));
}
static inline float op_floor(float x) {
return floorf(x);
}
@@ -290,6 +298,14 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
unary_op<op_log>(params, dst);
}
void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_expm1>(params, dst);
}
void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_softplus>(params, dst);
}
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_floor>(params, dst);
}
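Two numerical details in these unary ops are worth a quick check. Softplus switches to the identity for inputs above 20, where `log(1 + exp(x))` and `x` agree to within float precision. The CPU `op_expm1` is written as `expf(x) - 1.0f`, which can lose the result entirely for very small inputs, whereas the library `expm1` keeps it. The sketch below uses hypothetical reference helpers (`softplus_ref`, `expm1_naive`) that mirror the formulas in this diff; they are not ggml functions.

```cpp
#include <cmath>

// Hypothetical reference helper mirroring op_softplus from the diff.
static float softplus_ref(float x) {
    // for large x, log(1 + exp(x)) is indistinguishable from x in float,
    // and the shortcut also avoids overflow in exp(x)
    return (x > 20.0f) ? x : std::log(1.0f + std::exp(x));
}

// Hypothetical reference helper mirroring the CPU op_expm1 from the diff.
static float expm1_naive(float x) {
    // exp(x) is rounded to float before the subtraction, so for tiny x
    // the result is typically rounded away to zero
    return std::exp(x) - 1.0f;
}
```

Comparing `expm1_naive(1e-10f)` with `std::expm1(1e-10f)` shows the difference: the library function returns a value close to 1e-10, while the naive form usually returns 0 because `exp(1e-10f)` rounds to 1.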
+2
@@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+10
@@ -1416,6 +1416,16 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#endif
}
inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
if (i == 0) {
y[i] = x[i];
} else {
y[i] = y[i - 1] + x[i];
}
}
}
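`ggml_vec_cumsum_f32` is an inclusive prefix sum: element `i` of the output is the sum of inputs `0..i`. A small standalone sketch of the same recurrence on a `std::vector` is shown here; `cumsum` is a hypothetical wrapper for illustration only.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: inclusive prefix sum, same recurrence as ggml_vec_cumsum_f32.
static std::vector<float> cumsum(const std::vector<float> & x) {
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        // the first element is copied; every later element adds to the previous output
        y[i] = (i == 0) ? x[i] : y[i - 1] + x[i];
    }
    return y;
}
```

Because each output depends on the previous one, the loop is inherently sequential along a row, so parallelism has to come from processing separate rows independently.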
inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
ggml_float sum = 0.0;
for (int i = 0; i < n; ++i) {
+8
@@ -2527,6 +2527,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_TRUNC:
ggml_cuda_op_trunc(ctx, dst);
break;
case GGML_UNARY_OP_EXPM1:
ggml_cuda_op_expm1(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_cuda_op_softplus(ctx, dst);
break;
default:
return false;
}
@@ -3829,6 +3835,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_EXPM1:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
+16
@@ -81,6 +81,14 @@ static __device__ __forceinline__ float op_log(float x) {
return logf(x);
}
static __device__ __forceinline__ float op_expm1(float x) {
return expm1f(x);
}
static __device__ __forceinline__ float op_softplus(float x) {
return (x > 20.0f) ? x : logf(1.0f + expf(x));
}
static __device__ __forceinline__ float op_elu(float x) {
return (x > 0.f) ? x : expm1f(x);
}
@@ -233,6 +241,14 @@ void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_trunc>(ctx, dst);
}
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_expm1>(ctx, dst);
}
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_softplus>(ctx, dst);
}
/* gated ops */
template <float (*op)(float), typename T>
+4
@@ -61,6 +61,10 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+1 -1
@@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) {
}
}
static inline float ggml_softplus(float input) {
static inline float ggml_compute_softplus_f32(float input) {
return (input > 20.0f) ? input : logf(1 + expf(input));
}
//
+28
@@ -943,6 +943,34 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_ARGSORT);
char base[256];
char name[256];
ggml_sort_order order = (ggml_sort_order) op->op_params[0];
const char * order_str = "undefined";
switch (order) {
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
+1
@@ -125,6 +125,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
-2
@@ -904,8 +904,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
return op->src[0]->ne[0] <= 1024;
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:
+20 -2
@@ -793,10 +793,28 @@ typedef struct {
} ggml_metal_kargs_leaky_relu;
typedef struct {
int64_t ncols;
int64_t ncols_pad;
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
} ggml_metal_kargs_argsort;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t len;
} ggml_metal_kargs_argsort_merge;
typedef struct {
int64_t ne0;
float start;
+84 -20
@@ -1975,7 +1975,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
const bool has_mask = op->src[3] != nullptr;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
// note: always reserve the padding space to avoid graph reallocations
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
const bool has_kvpad = true;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
@@ -1984,7 +1986,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
const bool has_kvpad = true;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_NCPSG*(
@@ -2020,9 +2023,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
// this optimization is not useful for the vector kernels
if (is_vec) {
return res;
}
// note: always reserve the blk buffer to avoid graph reallocations
//if (is_vec) {
// return res;
//}
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
@@ -2049,13 +2053,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
size_t res = 0;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
// note: always reserve the temp buffer to avoid graph reallocations
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
if (true) {
const int64_t nwg = 32;
const int64_t ne01_max = std::min(ne01, 32);
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
}
return res;
@@ -3523,38 +3530,95 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
// bitonic sort requires the number of elements to be a power of 2
int64_t ne00_padded = 1;
while (ne00_padded < ne00) {
ne00_padded *= 2;
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
const int64_t nrows = ggml_nrows(op->src[0]);
// bitonic sort requires the number of elements to be a power of 2
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
const int nptg = (ne00 + nth - 1)/nth;
// Metal kernels require the buffer size to be a multiple of 16 bytes
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
std::swap(bid_dst, bid_tmp);
}
ggml_metal_kargs_argsort args = {
/*.ncols =*/ ne00,
/*.ncols_pad =*/ ne00_padded
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
int len = nth;
while (len < ne00) {
ggml_metal_op_concurrency_reset(ctx);
ggml_metal_kargs_argsort_merge args_merge = {
.ne00 = ne00,
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb00 = nb00,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.len = len,
};
// merges per row
const int nm = (ne00 + 2*len - 1) / (2*len);
const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
std::swap(bid_dst, bid_tmp);
len <<= 1;
}
return 1;
}
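The host code above ping-pongs between the destination buffer and a temporary region placed right after it: the block sort writes into one buffer, and every merge pass reads from the current buffer and writes into the other. The initial swap, taken when the number of passes is odd, makes the final pass land in the real destination. The sketch below checks that bookkeeping. `count_merge_passes` and `final_result_buffer` are hypothetical names, and the pass count is computed with an integer loop rather than the `ceil(log(nptg) / log(2))` expression used in the diff.

```cpp
#include <cstdint>
#include <utility>

enum class Buf { Dst, Tmp };

// Hypothetical helper: number of merge passes needed when sorted runs of
// length nth are doubled until one run covers the whole row of ne00 elements.
static int count_merge_passes(int64_t ne00, int64_t nth) {
    int passes = 0;
    for (int64_t len = nth; len < ne00; len <<= 1) {
        ++passes;
    }
    return passes;
}

// Hypothetical helper: which buffer holds the fully sorted indices at the end.
static Buf final_result_buffer(int64_t ne00, int64_t nth) {
    Buf cur = Buf::Dst, other = Buf::Tmp;
    if (count_merge_passes(ne00, nth) % 2 == 1) {
        std::swap(cur, other); // block sort writes into the temporary region first
    }
    for (int64_t len = nth; len < ne00; len <<= 1) {
        std::swap(cur, other); // each merge reads `cur` and writes `other`
    }
    return cur;
}
```

For a row of 5000 elements and blocks of 1024, the run length grows 1024, 2048, 4096, giving three passes, and the result still ends up in the destination buffer. The doubled allocation size for `GGML_OP_ARGSORT` elsewhere in this diff is what provides the temporary region.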
+4
@@ -197,6 +197,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
} break;
case GGML_OP_ARGSORT:
{
res *= 2;
} break;
default:
break;
}
+137 -27
@@ -4541,69 +4541,179 @@ kernel void kernel_timestep_embedding_f32(
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const float * x,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const float * x,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
int col = tpitg[0];
int row = tgpig[1];
const int col = tpitg[0];
if (col >= args.ncols_pad) return;
const int i00 = (tgpig[0]/args.ne01)*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * x_row = x + row * args.ncols;
threadgroup int32_t * dst_row = shared_values;
device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
dst_row[col] = col;
smem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int k = 2; k <= args.ncols_pad; k *= 2) {
for (int k = 2; k <= ntg.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= args.ncols ||
(dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
if (smem_i32[col] >= args.ne00 ||
(smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
) {
SWAP(dst_row[col], dst_row[ixj]);
SWAP(smem_i32[col], smem_i32[ixj]);
}
} else {
if (dst_row[ixj] >= args.ncols ||
(dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
if (smem_i32[ixj] >= args.ne00 ||
(smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
) {
SWAP(dst_row[col], dst_row[ixj]);
SWAP(smem_i32[col], smem_i32[ixj]);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
// copy the result to dst without the padding
if (col < args.ncols) {
dst[row * args.ncols + col] = dst_row[col];
if (i00 + col < args.ne00) {
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
dst[col] = smem_i32[col];
}
}
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
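The compare-exchange network in `kernel_argsort_f32_i32` can be simulated sequentially on the CPU, which makes the padding rule easier to follow: indices at or beyond the real element count behave like +infinity and sink to the end. The sketch below applies the network to one row padded to a power of two, ascending order only. `bitonic_argsort_asc` is a hypothetical helper, not the Metal kernel itself, and it sorts a whole row rather than one threadgroup-sized block.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical CPU simulation of a bitonic argsort (ascending) with padding.
static std::vector<int32_t> bitonic_argsort_asc(const std::vector<float> & x) {
    const int32_t n = (int32_t) x.size();
    int32_t npad = 1;
    while (npad < n) {
        npad *= 2; // the network needs a power-of-two element count
    }
    std::vector<int32_t> idx(npad);
    for (int32_t c = 0; c < npad; ++c) {
        idx[c] = c; // indices >= n are padding
    }
    for (int32_t k = 2; k <= npad; k *= 2) {
        for (int32_t j = k / 2; j > 0; j /= 2) {
            for (int32_t col = 0; col < npad; ++col) {
                const int32_t ixj = col ^ j;
                if (ixj <= col) {
                    continue; // each pair is handled once, by its lower member
                }
                // in an ascending half the smaller value goes to col, otherwise to ixj
                const int32_t lo = ((col & k) == 0) ? col : ixj;
                const int32_t hi = ((col & k) == 0) ? ixj : col;
                if (idx[lo] >= n || (idx[hi] < n && x[idx[lo]] > x[idx[hi]])) {
                    std::swap(idx[lo], idx[hi]);
                }
            }
        }
    }
    idx.resize(n); // drop the padding entries, which have been sorted to the end
    return idx;
}
```

On the GPU every `col` runs as its own thread with a barrier after each `j` step; the sequential loop here gives the same result because each step touches disjoint pairs.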
typedef void (argsort_merge_t)(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int im = tgpig[0] / args.ne01;
int i01 = tgpig[0] % args.ne01;
int i02 = tgpig[1];
int i03 = tgpig[2];
const int start = im * (2*args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
+ i01*args.ne00
+ i02*args.ne00*args.ne01
+ i03*args.ne00*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
+ i01*args.ne00
+ i02*args.ne00*args.ne01
+ i03*args.ne00*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
+ args.nb02*i02
+ args.nb03*i03);
for (int k = tpitg.x; k < (int) total; k += ntg.x) {
// find partition (i,j) such that i+j = k
int low = k > len1 ? k - len1 : 0;
int high = MIN(k, len0);
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
if (order == GGML_SORT_ORDER_ASC) {
if (val0 <= val1) {
low = mid + 1;
} else {
high = mid;
}
} else {
if (val0 >= val1) {
low = mid + 1;
} else {
high = mid;
}
}
}
const int i = low;
const int j = k - i;
int32_t out_idx;
if (i >= len0) {
out_idx = tmp1[j];
} else if (j >= len1) {
out_idx = tmp0[i];
} else {
const int32_t idx0 = tmp0[i];
const int32_t idx1 = tmp1[j];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
out_idx = (order == GGML_SORT_ORDER_ASC)
? (val0 <= val1 ? idx0 : idx1)
: (val0 >= val1 ? idx0 : idx1);
}
dst[k] = out_idx;
}
}
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
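The merge kernel computes every output position independently: for output slot `k` it binary-searches how many elements `i` come from the first run and `j = k - i` from the second, then picks the smaller of the two candidates. A CPU port of that loop for ascending order is sketched here, assuming two already-sorted runs of indices into a shared value array. `merge_runs_asc` is a hypothetical name used for illustration.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical CPU port of the per-output partition search (ascending order).
// run0 and run1 hold indices into vals, each run sorted by the referenced value.
static std::vector<int32_t> merge_runs_asc(const std::vector<float> & vals,
                                           const std::vector<int32_t> & run0,
                                           const std::vector<int32_t> & run1) {
    const int len0 = (int) run0.size();
    const int len1 = (int) run1.size();
    std::vector<int32_t> out(len0 + len1);
    for (int k = 0; k < len0 + len1; ++k) {
        // find the partition (i, j) with i + j = k
        int low  = k > len1 ? k - len1 : 0;
        int high = std::min(k, len0);
        while (low < high) {
            const int mid = (low + high) >> 1;
            if (vals[run0[mid]] <= vals[run1[k - mid - 1]]) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        const int i = low;
        const int j = k - i;
        if (i >= len0) {
            out[k] = run1[j]; // first run exhausted
        } else if (j >= len1) {
            out[k] = run0[i]; // second run exhausted
        } else {
            // ties go to the first run
            out[k] = (vals[run0[i]] <= vals[run1[j]]) ? run0[i] : run1[j];
        }
    }
    return out;
}
```

Since no output slot depends on another, the `k` loop maps directly onto threads, which is how the kernel strides `k` by `ntg.x`.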
kernel void kernel_leaky_relu_f32(
constant ggml_metal_kargs_leaky_relu & args,
device const float * src0,
File diff suppressed because it is too large
@@ -0,0 +1,21 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(abs(float(data_a[i])));
}
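The element index in this shader is flattened from a 3D invocation ID with fixed strides of 512 and 262144 (512 × 512). The arithmetic is sketched below; `flat_index` is a hypothetical helper, and the interpretation that X spans 512 invocations and Y spans up to 512 workgroups is an assumption read off the constants, not something the diff states.

```cpp
#include <cstdint>

// Hypothetical helper reproducing the shader's index computation.
// Assumes x < 512 and y < 512 so that distinct IDs map to distinct indices.
static uint32_t flat_index(uint32_t x, uint32_t y, uint32_t z) {
    return z * 262144u + y * 512u + x;
}
```

Invocations whose flattened index is at or beyond `p.KX` return early, so the dispatch can be rounded up to whole workgroups without writing out of bounds.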
@@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
#include "types.glsl"
#include "flash_attn_base.glsl"
@@ -108,6 +109,38 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
} else {
masksh[c][r] = float(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
@@ -153,21 +186,6 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
}
}
barrier();
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
@@ -7,6 +7,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
@@ -148,6 +149,37 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float mask_cache[Bc * Br / WorkGroupSize];
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
mask_cache[idx / WorkGroupSize] = m;
max_mask = max(max_mask, m);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@@ -208,7 +240,8 @@ void main() {
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
float f = mask_cache[idx / WorkGroupSize];
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
}
}
}
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return max(x, y);
}
float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
return max(x, y);
}
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return x;
}
@@ -142,6 +146,44 @@ void main() {
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
// skip the block if the mask is entirely -inf
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
@@ -158,31 +200,7 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
// Clear padding elements to -inf, so they don't contribute to rowmax
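Review note on the flash-attention hunks above: the new early `continue` relies on the fact that a KV block whose mask is entirely -inf contributes nothing to the streaming softmax. The sketch below is a scalar Python model of that argument, not the shader itself. The names `attention_row`, `block` and `skip_masked_blocks` are made up for illustration, and it tests for exact -inf where the shaders compare against `NEG_FLT_MAX_OVER_2`.

```python
import math

NEG_INF = float("-inf")

def attention_row(scores, mask, values, block, skip_masked_blocks):
    """Online-softmax attention for one query row, processed in KV blocks.

    Returns (output, number_of_blocks_actually_processed).
    """
    m = NEG_INF  # running maximum of the masked scores
    l = 0.0      # running softmax denominator
    o = 0.0      # running weighted sum of the values
    visited = 0
    for start in range(0, len(scores), block):
        blk = range(start, min(start + block, len(scores)))
        if skip_masked_blocks and all(mask[j] == NEG_INF for j in blk):
            continue  # every weight in this block would be exp(-inf) = 0
        visited += 1
        s = [scores[j] + mask[j] for j in blk]
        m_new = max(m, max(s))
        if m_new == NEG_INF:
            continue  # nothing finite seen yet, avoid (-inf) - (-inf)
        scale = math.exp(m - m_new) if m != NEG_INF else 0.0
        p = [math.exp(x - m_new) for x in s]
        l = l * scale + sum(p)
        o = o * scale + sum(pj * values[j] for pj, j in zip(p, blk))
        m = m_new
    return o / l, visited
```

Running the function twice on a causal-style mask, once with `skip_masked_blocks=False` and once with `True`, should give the same output while visiting fewer blocks in the second case. The saving in the real shaders comes from skipping the K/V loads and matrix multiplies for those blocks.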
@@ -0,0 +1,20 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(-float(data_a[i]));
}
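Review note on the new NEG shader above: each invocation flattens its 3D global ID into one element index and returns early when the index is past `p.KX`. The Python sketch below imitates that indexing on the CPU. It assumes a single workgroup of 512 invocations along x, with extra workgroups spread over y and z, which is what the strides 512 and 262144 (512 * 512) suggest. `run_neg`, `groups_y` and `groups_z` are illustrative names, not part of the backend.

```python
def flat_index(gid_x, gid_y, gid_z):
    """Same flattening as the shader: z * 262144 + y * 512 + x."""
    return gid_z * 262144 + gid_y * 512 + gid_x

def run_neg(data_a, groups_y=1, groups_z=1):
    """Emulate the dispatch: every invocation negates at most one element."""
    kx = len(data_a)
    out = [0.0] * kx
    for z in range(groups_z):
        for y in range(groups_y):
            for x in range(512):  # local_size_x = 512
                i = flat_index(x, y, z)
                if i >= kx:
                    continue  # matches the `if (i >= p.KX) return;` guard
                out[i] = -float(data_a[i])
    return out
```

With 1000 input elements and `groups_y=2`, the emulation launches 1024 invocations and the guard discards the last 24.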
@@ -76,7 +76,7 @@ enum MatMulIdType {
namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
@@ -99,8 +99,10 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
si.hStdOutput = stdout_write;
si.hStdError = stderr_write;
std::vector<char> cmd(command.begin(), command.end());
cmd.push_back('\0');
std::string cmd;
for (const auto& part : command) {
cmd += part + " ";
}
if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
throw std::runtime_error("Failed to create process");
@@ -138,6 +140,12 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
throw std::runtime_error("Failed to fork process");
}
std::vector<char*> argv;
for (std::string& part : command) {
argv.push_back(part.data());
}
argv.push_back(nullptr);
if (pid == 0) {
close(stdout_pipe[0]);
close(stderr_pipe[0]);
@@ -145,7 +153,7 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
dup2(stderr_pipe[1], STDERR_FILENO);
close(stdout_pipe[1]);
close(stderr_pipe[1]);
execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
execvp(argv[0], argv.data());
_exit(EXIT_FAILURE);
} else {
close(stdout_pipe[1]);
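Review note on the `execl("/bin/sh", ...)` to `execvp(argv[0], argv.data())` change above: passing an argument vector directly means no shell re-parses the command, so paths containing spaces or shell metacharacters reach the child unchanged. The Python sketch below shows the same idea with `subprocess.run` and a list of arguments. It illustrates the technique only and is not part of the shader generator; `run_argv` is a made-up helper name.

```python
import subprocess
import sys

def run_argv(argv):
    """Run a command from an argument list (no shell) and capture its output."""
    proc = subprocess.run(argv, capture_output=True, text=True)
    return proc.stdout, proc.stderr

if __name__ == "__main__":
    # The awkward argument arrives in the child exactly as written.
    out, err = run_argv([sys.executable, "-c",
                         "import sys; print(sys.argv[1])",
                         "my shader; $HOME.comp"])
    print(out, end="")
```

This is presumably also why the non-Windows branch of `string_to_spv_func` passes `in_path` and `out_path` without added quotes, while the Windows branch, which joins everything into one command-line string, still quotes them.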
@@ -316,21 +324,27 @@ compile_count_guard acquire_compile_slot() {
void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
#ifdef _WIN32
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
#else
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
#endif
// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos || name.find("rope") != std::string::npos) ? "" : "-O";
#ifdef _WIN32
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
#else
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path};
#endif
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
cmd.push_back("-O");
}
if (dep_file) {
cmd.push_back("-MD");
cmd.push_back("-MF");
#ifdef _WIN32
cmd.push_back("\"" + target_cpp + ".d\"");
#else
cmd.push_back(target_cpp + ".d");
#endif
}
#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
@@ -354,9 +368,13 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// }
// std::cout << std::endl;
execute_command(command, stdout_str, stderr_str);
execute_command(cmd, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
std::cerr << "cannot compile " << name << "\n\n";
for (const auto& part : cmd) {
std::cerr << part << " ";
}
std::cerr << "\n\n" << stderr_str << std::endl;
return;
}
@@ -430,7 +448,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
if (f16acc) {
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
if (coopmat) {
@@ -610,7 +628,7 @@ void process_shaders() {
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
for (const auto& tname : type_names) {
@@ -809,6 +827,8 @@ void process_shaders() {
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -817,6 +837,8 @@ void process_shaders() {
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
@@ -1081,11 +1103,6 @@ int main(int argc, char** argv) {
if (args.find("--glslc") != args.end()) {
GLSLC = args["--glslc"]; // Path to glslc
if (!std::filesystem::exists(GLSLC) || !std::filesystem::is_regular_file(GLSLC)) {
std::cerr << "Error: glslc not found at " << GLSLC << std::endl;
return EXIT_FAILURE;
}
}
if (args.find("--source") != args.end()) {
input_filepath = args["--source"]; // The shader source file to compile
+154 -5
@@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"COS",
"SUM",
"SUM_ROWS",
"CUMSUM",
"MEAN",
"ARGMAX",
"COUNT_EQUAL",
@@ -990,6 +991,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",
"TRI",
"FILL",
"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
@@ -1002,6 +1005,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RWKV_WKV6",
"GATED_LINEAR_ATTN",
"RWKV_WKV7",
"SOLVE_TRI",
"UNARY",
@@ -1019,7 +1023,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1039,6 +1043,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cos(x)",
"Σx",
"Σx_k",
"cumsum(x)",
"Σx/n",
"argmax(x)",
"count_equal(x)",
@@ -1094,6 +1099,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",
"tri(x)",
"fill(x, c)",
"flash_attn_ext(x)",
"flash_attn_back(x)",
@@ -1106,6 +1113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rwkv_wkv6(k, v, r, tf, td, s)",
"gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"A X = B, A triangular, solve X",
"unary(x)",
@@ -1123,7 +1131,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -1142,6 +1150,8 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSWISH",
"HARDSIGMOID",
"EXP",
"EXPM1",
"SOFTPLUS",
"GELU_ERF",
"XIELU",
"FLOOR",
@@ -1150,7 +1160,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"TRUNC",
};
static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20");
static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
@@ -2258,6 +2268,30 @@ struct ggml_tensor * ggml_log_inplace(
return ggml_log_impl(ctx, a, true);
}
struct ggml_tensor * ggml_expm1(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1);
}
struct ggml_tensor * ggml_expm1_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1);
}
struct ggml_tensor * ggml_softplus(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);
}
struct ggml_tensor * ggml_softplus_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);
}
// ggml_sin
static struct ggml_tensor * ggml_sin_impl(
@@ -2341,6 +2375,21 @@ struct ggml_tensor * ggml_sum_rows(
return result;
}
// ggml_cumsum
struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
result->op = GGML_OP_CUMSUM;
result->src[0] = a;
return result;
}
// ggml_mean
struct ggml_tensor * ggml_mean(
@@ -2668,8 +2717,8 @@ struct ggml_tensor * ggml_xielu(
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
ggml_set_op_params_f32(result, 1, beta + ggml_compute_softplus_f32(alpha_n));
ggml_set_op_params_f32(result, 2, ggml_compute_softplus_f32(alpha_p));
ggml_set_op_params_f32(result, 3, beta);
ggml_set_op_params_f32(result, 4, eps);
@@ -5028,6 +5077,61 @@ struct ggml_tensor * ggml_timestep_embedding(
return result;
}
// ggml_tri
struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type type) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(a->ne[0] == a->ne[1]);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
ggml_set_op_params_i32(result, 0, type);
result->op = GGML_OP_TRI;
result->src[0] = a;
return result;
}
// ggml_fill
static struct ggml_tensor * ggml_fill_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c,
bool inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(a));
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params_f32(result, 0, c);
result->op = GGML_OP_FILL;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_fill(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c) {
return ggml_fill_impl(ctx, a, c, false);
}
struct ggml_tensor * ggml_fill_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float c) {
return ggml_fill_impl(ctx, a, c, true);
}
// ggml_argsort
struct ggml_tensor * ggml_argsort(
@@ -5882,6 +5986,41 @@ struct ggml_tensor * ggml_opt_step_sgd(
return result;
}
// solve_tri
struct ggml_tensor * ggml_solve_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
bool left,
bool lower,
bool uni) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(b->type == GGML_TYPE_F32);
// A must be square and lower triangular
GGML_ASSERT(a->ne[0] == a->ne[1]);
// B must have same outer dimension as A
GGML_ASSERT(a->ne[1] == b->ne[1]);
// batch dimensions must be equal
GGML_ASSERT(a->ne[2] == b->ne[2]);
GGML_ASSERT(a->ne[3] == b->ne[3]);
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(ggml_is_contiguous(b));
GGML_ASSERT(lower && left && !uni); // TODO: support other variants
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
result->op = GGML_OP_SOLVE_TRI;
result->src[0] = a;
result->src[1] = b;
return result;
}
////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) {
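Review note on `ggml_solve_tri` above: the only variant the assert currently allows (`left`, `lower`, not `uni`) solves A X = B for a lower-triangular A with a general diagonal. A minimal forward-substitution sketch of that case follows. It uses plain row-major Python lists rather than ggml's tensor layout, and `solve_lower_tri` is an illustrative name, not the backend kernel.

```python
def solve_lower_tri(a, b):
    """Solve A X = B by forward substitution.

    a: n x n lower-triangular matrix (list of rows, non-zero diagonal)
    b: n x k right-hand side (list of rows)
    """
    n = len(a)
    k = len(b[0])
    x = [[0.0] * k for _ in range(n)]
    for i in range(n):          # rows from top to bottom
        for c in range(k):      # each right-hand-side column independently
            acc = b[i][c]
            for j in range(i):  # subtract the already-solved unknowns
                acc -= a[i][j] * x[j][c]
            x[i][c] = acc / a[i][i]
    return x
```

Row i depends only on rows 0..i-1, so rows must be processed in order, while the columns of B are independent of each other.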
@@ -6454,6 +6593,16 @@ static void ggml_compute_backward(
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
}
} break;
case GGML_UNARY_OP_EXPM1: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0)));
}
} break;
case GGML_UNARY_OP_SOFTPLUS: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sigmoid(ctx, src0)));
}
} break;
default: {
fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
__func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
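Review note on the two new backward cases above: they use d/dx expm1(x) = exp(x) and d/dx softplus(x) = sigmoid(x). The sketch below compares both against a central finite difference. It is a standalone numerical sanity check with illustrative helper names and does not call ggml.

```python
import math

def softplus(x):
    """softplus(x) = log(1 + exp(x))."""
    return math.log1p(math.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def numeric_grad(f, x, h=1e-6):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

def max_grad_error(f, analytic, points):
    """Largest gap between the numeric and the analytic derivative."""
    return max(abs(numeric_grad(f, x) - analytic(x)) for x in points)

if __name__ == "__main__":
    pts = [-2.0, -0.5, 0.0, 1.5]
    print("expm1   :", max_grad_error(math.expm1, math.exp, pts))
    print("softplus:", max_grad_error(softplus, sigmoid, pts))
```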
+31
@@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum):
BAILINGMOE2 = auto()
DOTS1 = auto()
ARCEE = auto()
AFMOE = auto()
ERNIE4_5 = auto()
ERNIE4_5_MOE = auto()
HUNYUAN_MOE = auto()
@@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_POST_NORM = auto()
ATTN_ROT_EMBD = auto()
ATTN_SINKS = auto()
ATTN_GATE = auto()
FFN_GATE_INP = auto()
FFN_GATE_INP_SHEXP = auto()
FFN_NORM = auto()
@@ -776,6 +778,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.AFMOE: "afmoe",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
MODEL_ARCH.FALCON_H1: "falcon-h1",
@@ -828,6 +831,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
@@ -2693,6 +2697,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.AFMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.ERNIE4_5: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
+8 -1
@@ -314,6 +314,10 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.sinks", # openai-moe
),
MODEL_TENSOR.ATTN_GATE: (
"model.layers.{bid}.self_attn.gate_proj", # afmoe
),
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
@@ -340,11 +344,12 @@ class TensorNameMap:
"model.layers.{bid}.feedforward_layernorm", # apertus
),
# Post feed-forward norm
# Pre feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
"layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.pre_ff_layernorm.weight",
"model.layers.{bid}.pre_mlp_layernorm", # afmoe
),
# Post feed-forward norm
@@ -370,6 +375,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate.wg", # hunyuan
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
"model.layers.{bid}.feed_forward.gate", # lfm2moe
"model.layers.{bid}.mlp.router.gate", # afmoe
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -380,6 +386,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
"model.layers.{bid}.mlp.expert_bias", # afmoe
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
),
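Review note on the tensor-name mappings above: each entry lists checkpoint name templates containing `{bid}` that resolve to a single GGUF tensor name. The sketch below shows one way such a table can be expanded into a per-block lookup. The two sample entries are taken from the hunks in this diff; `MAPPINGS` and `build_lookup` are illustrative names and not the real `TensorNameMap` implementation.

```python
# GGUF name template -> checkpoint name templates (two samples from this diff)
MAPPINGS = {
    "blk.{bid}.attn_gate": ("model.layers.{bid}.self_attn.gate_proj",),
    "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.mlp.router.gate",),
}

def build_lookup(mappings, n_blocks):
    """Expand every template for block ids 0..n_blocks-1 into a flat dict."""
    lookup = {}
    for target, sources in mappings.items():
        for bid in range(n_blocks):
            for src in sources:
                lookup[src.format(bid=bid)] = target.format(bid=bid)
    return lookup
```

A converter built this way can look up a checkpoint tensor name directly and treat a missing key as an unmapped tensor.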
+1
@@ -35,6 +35,7 @@ add_library(llama
unicode-data.cpp
unicode.cpp
unicode.h
models/afmoe.cpp
models/apertus.cpp
models/arcee.cpp
models/arctic.cpp
+32
@@ -90,6 +90,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_AFMOE, "afmoe" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
@@ -333,6 +334,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_AFMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_LLAMA4,
{
@@ -2444,6 +2475,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+2
@@ -94,6 +94,7 @@ enum llm_arch {
LLM_ARCH_BAILINGMOE2,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_AFMOE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
@@ -312,6 +313,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
+102
@@ -84,6 +84,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_15B: return "15B";
case LLM_TYPE_16B: return "16B";
case LLM_TYPE_20B: return "20B";
case LLM_TYPE_26B: return "26B";
case LLM_TYPE_27B: return "27B";
case LLM_TYPE_30B: return "30B";
case LLM_TYPE_32B: return "32B";
@@ -695,6 +696,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_AFMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
// Set up interleaved sliding window attention (ISWA)
// Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4)
if (hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(4);
} else {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
}
// Default to sigmoid if not set
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
}
switch (hparams.n_layer) {
case 56: type = LLM_TYPE_6B; break;
case 32: type = LLM_TYPE_26B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_DECI:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
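Review note on the AFMOE `load_hparams` case above: the comment describes the interleaved pattern as three sliding-window layers followed by one full-attention layer. The sketch below spells out which layers that selects. It assumes `set_swa_pattern(4)` treats every fourth layer as full attention, as the comment states; the helper name `swa_layers` is illustrative and is not the llama.cpp function.

```python
def swa_layers(n_layer, n_pattern):
    """True where a layer uses sliding-window attention.

    Assumes n_pattern >= 1 and that the last layer of each group of
    n_pattern layers uses full attention.
    """
    return [(il % n_pattern) < (n_pattern - 1) for il in range(n_layer)]

if __name__ == "__main__":
    for il, is_swa in enumerate(swa_layers(8, 4)):
        print(il, "sliding" if is_swa else "full")
```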
@@ -5749,6 +5781,71 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_AFMOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
const int64_t n_ff_exp = hparams.n_ff_exp;
const int64_t n_expert_shared = hparams.n_expert_shared;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// dual attention normalization
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
// attention projections
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
// Q/K normalization
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
// attention gating
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
// dual ffn normalization
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
// MoE layers
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
// grouped expert weights
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
// shared expert
if (n_expert_shared > 0) {
const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
}
} else {
// Dense layers
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
}
} break;
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
{
@@ -7243,6 +7340,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_arcee>(*this, params);
} break;
case LLM_ARCH_AFMOE:
{
llm = std::make_unique<llm_build_afmoe>(*this, params);
} break;
case LLM_ARCH_ERNIE4_5:
{
llm = std::make_unique<llm_build_ernie4_5>(*this, params);
@@ -7528,6 +7629,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_MINIMAX_M2:
case LLM_ARCH_COGVLM:
case LLM_ARCH_PANGU_EMBED:
case LLM_ARCH_AFMOE:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
+2
@@ -76,6 +76,7 @@ enum llm_type {
LLM_TYPE_15B,
LLM_TYPE_16B,
LLM_TYPE_20B,
LLM_TYPE_26B,
LLM_TYPE_27B,
LLM_TYPE_30B,
LLM_TYPE_32B,
@@ -234,6 +235,7 @@ struct llama_layer {
struct ggml_tensor * wk_enc = nullptr;
struct ggml_tensor * wv_enc = nullptr;
struct ggml_tensor * wo_enc = nullptr;
struct ggml_tensor * wqkv_gate = nullptr;
// attention bias
struct ggml_tensor * bq = nullptr;
+10 -5
@@ -4,6 +4,7 @@
#include "llama-vocab.h"
#include "llama-grammar.h"
#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
@@ -1625,10 +1626,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') {
std::string trigger_pattern;
llama_grammar * grammar = nullptr;
// TODO: remove trigger_words support.
if (trigger_words != nullptr && num_trigger_words > 0) {
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
std::string trigger_pattern("[\\s\\S]*?(");
trigger_pattern = "[\\s\\S]*?(";
for (size_t i = 0; i < num_trigger_words; ++i) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
if (i > 0) {
@@ -1637,15 +1640,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
}
trigger_pattern += ")[\\s\\S]*";
const auto * trigger_pattern_c = trigger_pattern.c_str();
trigger_patterns = &trigger_pattern_c;
num_trigger_patterns = 1;
std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
} else {
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
}
*ctx = {
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
/* .grammar = */ grammar,
};
if (!ctx->grammar) {
delete ctx;
+15
@@ -443,6 +443,17 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_AFMOE:
regex_exprs = {
// Digit handling - uses custom implementation in unicode.cpp
// Groups digits with leading 1-2 based on total length modulo 3
"\\p{AFMoE_digits}",
// CJK and Asian scripts (using direct Unicode literals)
"[一-鿿㐀-䶿豈-﫿぀-ゟ゠-ヿ・-゚⼀-⿟เ-๿຀-໿ក-៿က-႟ꩠ-ꩿꧠ-꧿가-힯ᄀ-ᇿ]+",
// Main BPE pattern
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -1993,6 +2004,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "grok-2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
clean_spaces = false;
} else if (
tokenizer_pre == "afmoe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_AFMOE;
clean_spaces = false;
} else if (
tokenizer_pre == "minimax-m2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
+1
@@ -50,6 +50,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
};
struct LLM_KV;
+187
@@ -0,0 +1,187 @@
#include "models.h"
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// MuP scaling: embeddings * sqrt(hidden_size)
// mup_enabled = true, hidden_size = 1024, scale = 32.0
inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
cb(inpL, "inp_embd_scaled", -1);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// dual attention normalization (pre)
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
ggml_tensor * attn_inp = cur; // save input for gate computation
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
// compute gate from input
ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
cb(gate, "attn_gate_proj", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
// Q/K normalization
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
// RoPE only for sliding_attention layers
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
((il + 1) % hparams.n_no_rope_layer_step) != 0;
if (use_rope) {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_rope", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur_rope", il);
}
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
cur = build_attn(inp_attn,
NULL, NULL, // wo will be applied after gating
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
// attention gating: attn_out * sigmoid(gate) BEFORE o_proj
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "attn_gate_sig", il);
cur = ggml_mul(ctx0, cur, gate);
cb(cur, "attn_gated", il);
// now apply output projection
cur = build_lora_mm(model.layers[il].wo, cur);
cb(cur, "attn_o_proj", il);
}
// dual attention normalization (post)
cur = build_norm(cur,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// dual ffn normalization (pre)
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
// MoE or dense FFN
if ((uint32_t)il >= hparams.n_layer_dense_lead) {
// MoE layer with sigmoid routing, normalization, and scaling
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU,
hparams.expert_weights_norm, // norm_w (route_norm=True)
hparams.expert_weights_scale, // scale_w
hparams.expert_weights_scale, // w_scale (route_scale=2.826)
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);
// shared expert
if (hparams.n_expert_shared > 0) {
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
} else {
cur = moe_out;
}
} else {
// dense layer
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
// dual ffn normalization (post)
cur = build_norm(cur,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
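The two layer-level rules in the builder above (skipping RoPE on every `n_no_rope_layer_step`-th layer, and multiplying the attention output by `sigmoid(gate)` before the output projection) can be sketched outside of ggml. This is a minimal illustration on plain `std::vector<float>` data; `layer_uses_rope` and `gate_attention` are hypothetical names, not llama.cpp functions, and the sketch assumes both vectors have the same length.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the `use_rope` expression in the builder:
// layer il (0-based) skips RoPE when (il + 1) is a multiple of the step.
// A step of 0 yields false for every layer, exactly as in the source expression.
static bool layer_uses_rope(int il, uint32_t no_rope_layer_step) {
    return no_rope_layer_step > 0 && ((uint32_t) (il + 1) % no_rope_layer_step) != 0;
}

// Hypothetical helper: element-wise attn_out * sigmoid(gate), i.e. the gating
// that the builder applies before the output projection (wo).
// Assumes attn_out and gate have the same number of elements.
static std::vector<float> gate_attention(const std::vector<float> & attn_out,
                                         const std::vector<float> & gate) {
    std::vector<float> gated(attn_out.size());
    for (size_t i = 0; i < attn_out.size(); ++i) {
        const float sig = 1.0f / (1.0f + std::exp(-gate[i]));
        gated[i] = attn_out[i] * sig;
    }
    return gated;
}
```

With a step of 4, layers 0, 1 and 2 use RoPE and layer 3 does not; a gate value of 0 halves the corresponding attention output.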
+4
@@ -57,6 +57,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
int il) const;
};
struct llm_build_afmoe : public llm_graph_context {
llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_apertus : public llm_graph_context {
llm_build_apertus(const llama_model & model, const llm_graph_params & params);
};
+77
@@ -729,6 +729,80 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
return bpe_offsets;
}
// AFMOE digit handling: splits digits with leading 1-2 based on total length modulo 3
static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & text, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
bpe_offsets.reserve(offsets.size());
const auto cpts = unicode_cpts_from_utf8(text);
size_t start = 0;
for (auto offset : offsets) {
const size_t offset_ini = start;
const size_t offset_end = start + offset;
assert(offset_end <= cpts.size());
start = offset_end;
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
auto _add_token = [&] (const size_t end) -> size_t {
assert(_prev_end <= end && end <= offset_end);
size_t len = end - _prev_end;
if (len > 0) {
bpe_offsets.push_back(len);
}
_prev_end = end;
return len;
};
for (size_t pos = offset_ini; pos < offset_end; ) {
const auto flags = _get_flags(pos);
// Handle digit sequences with special splitting logic
if (flags.is_number) {
size_t digit_start = pos;
size_t digit_count = 0;
// Count consecutive digits
while (_get_flags(pos).is_number && pos < offset_end) {
digit_count++;
pos++;
}
// Split based on total length modulo 3
size_t remainder = digit_count % 3;
size_t current = digit_start;
// Emit leading 1-2 digits if needed
if (remainder > 0) {
_add_token(current + remainder);
current += remainder;
}
// Emit groups of 3
while (current < digit_start + digit_count) {
_add_token(current + 3);
current += 3;
}
continue;
}
// For non-digits, just move forward
pos++;
}
// Add any remaining content
if (_prev_end < offset_end) {
_add_token(offset_end);
}
}
return bpe_offsets;
}
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
std::vector<size_t> bpe_offsets;
@@ -742,6 +816,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
} else if (regex_expr == "\\p{Han}+") {
// K2's first pattern - handle all K2 patterns together
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
} else if (regex_expr == "\\p{AFMoE_digits}") {
// AFMOE digit pattern - use custom implementation for proper splitting
bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
}
return bpe_offsets;
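The digit rule implemented by `unicode_regex_split_custom_afmoe` (a leading group of 1-2 digits when the run length is not a multiple of 3, followed by groups of 3) can be sketched on a plain string. This is a simplified illustration that assumes the input is already a single run of digits; `group_digits_afmoe` is a hypothetical name and works on bytes rather than code points.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of llama.cpp): split one run of digits into
// an optional leading group of (length % 3) digits, then groups of 3.
static std::vector<std::string> group_digits_afmoe(const std::string & digits) {
    std::vector<std::string> groups;
    const size_t remainder = digits.size() % 3;
    size_t pos = 0;
    // Emit the leading 1-2 digits if the length is not a multiple of 3
    if (remainder > 0) {
        groups.push_back(digits.substr(0, remainder));
        pos = remainder;
    }
    // Emit the rest in groups of 3
    while (pos < digits.size()) {
        groups.push_back(digits.substr(pos, 3));
        pos += 3;
    }
    return groups;
}
```

For example, `group_digits_afmoe("1234567")` yields `"1"`, `"234"`, `"567"`, so the groups line up with thousands separators read from the right.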
+230 -3
@@ -175,6 +175,38 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m
ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t));
}
// generate a lower triangular matrix
static void init_tensor_tril(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
GGML_ASSERT(tensor->ne[0] == tensor->ne[1]);
GGML_TENSOR_LOCALS(int32_t, ne, tensor, ne);
GGML_TENSOR_LOCALS(size_t, nb, tensor, nb);
std::vector<float> data_f32(ne0*ne1*ne2*ne3);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(min, max);
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
for (int64_t i1 = 0; i1 < ne1; i1++) {
for (int64_t i0 = 0; i0 < ne0; i0++) {
int64_t idx = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3) / sizeof(float);
if (i0 <= i1) {
data_f32[idx] = dis(gen);
} else {
data_f32[idx] = 0.0f;
}
}
}
}
}
ggml_backend_tensor_set(tensor, data_f32.data(), 0, ggml_nbytes(tensor));
}
static std::vector<float> tensor_to_float(const ggml_tensor * t) {
std::vector<float> tv;
tv.reserve(ggml_nelements(t));
@@ -1804,7 +1836,8 @@ struct test_unary : public test_case {
ggml_tensor * build_graph(ggml_context * ctx) override {
const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||
op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU;
op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU ||
op == GGML_UNARY_OP_EXPM1 || op == GGML_UNARY_OP_SOFTPLUS;
ggml_tensor * a;
if (v & 1) {
@@ -2779,7 +2812,7 @@ struct test_bin_bcast : public test_case {
const std::array<int, 4> nr;
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
bool run_whole_graph() override { return true; }
bool run_whole_graph() override { return nf > 1; }
std::string vars() override {
return VARS_TO_STR4(type, ne, nr, nf);
@@ -5395,6 +5428,7 @@ struct test_pad : public test_case {
}
};
// GGML_OP_PAD (with extension)
struct test_pad_ext : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
@@ -5802,6 +5836,7 @@ struct test_opt_step_adamw : public test_case {
}
};
// GGML_OP_OPT_STEP_SGD
struct test_opt_step_sgd : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
@@ -5841,6 +5876,170 @@ struct test_opt_step_sgd : public test_case {
}
};
// GGML_OP_CUMSUM
struct test_cumsum : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override { return VARS_TO_STR2(type, ne); }
test_cumsum(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_cumsum(ctx, a);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -1.0f, 1.0f);
}
}
};
// GGML_OP_XIELU
struct test_xielu : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override { return VARS_TO_STR2(type, ne); }
test_xielu(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_param(a);
ggml_set_name(a, "a");
float alpha_n = 4.0f;
float alpha_p = 20.0f;
float beta = 0.5f;
float eps = 0.0000001f;
ggml_tensor * out = ggml_xielu(ctx, a, alpha_n, alpha_p, beta, eps);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -1.0f, 1.0f);
}
}
};
// GGML_OP_TRI
struct test_tri : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const ggml_tri_type tri_type;
std::string vars() override { return VARS_TO_STR3(type, ne, tri_type); }
test_tri(ggml_tri_type tri_type, ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
: type(type), ne(ne), tri_type(tri_type) {
GGML_ASSERT(ne[0] == ne[1]);
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_tri(ctx, a, tri_type);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -1.0f, 1.0f);
}
}
};
// GGML_OP_FILL
struct test_fill : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float c;
std::string vars() override { return VARS_TO_STR3(type, ne, c); }
test_fill(float c, ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
: type(type), ne(ne), c(c) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_fill(ctx, a, c);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_SOLVE_TRI
struct test_solve_tri : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_lhs;
const std::array<int64_t, 4> ne_rhs;
std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
test_solve_tri(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
)
: type(type), ne_lhs(ne_lhs), ne_rhs(ne_rhs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne_lhs[0], ne_lhs[1], ne_lhs[2], ne_lhs[3]);
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne_rhs[0], ne_rhs[1], ne_rhs[2], ne_rhs[3]);
ggml_set_param(b);
ggml_set_name(b, "b");
ggml_tensor * out = ggml_solve_tri(ctx, a, b, true, true, false);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (strcmp(t->name, "a") == 0) {
// note: avoid zeros in the diagonal
init_tensor_tril(t, 0.1, 1.0f);
} else {
init_tensor_uniform(t, -1.0f, 1.0f);
}
}
}
};
enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
@@ -6282,6 +6481,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
if (op == GGML_UNARY_OP_XIELU) {
continue; // need extra params, separate test
}
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
}
@@ -7290,8 +7492,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
}
@@ -7339,6 +7546,26 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_arange());
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
test_cases.emplace_back(new test_cumsum());
test_cases.emplace_back(new test_xielu());
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG));
test_cases.emplace_back(new test_fill(0.0f));
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
test_cases.emplace_back(new test_solve_tri());
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
for (bool v : {false, true}) {
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
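The `test_solve_tri` case initializes its left-hand side with `init_tensor_tril` and a minimum of 0.1 so that the diagonal is never zero. A minimal forward-substitution sketch shows why that matters: each step divides by a diagonal entry. This is an illustration only; `solve_lower_tri` is a hypothetical helper using a row-major `n x n` layout, and it does not model the tensor layout or the boolean flags of `ggml_solve_tri`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: solve L * x = b by forward substitution, where L is an
// n x n lower-triangular matrix stored row-major (assumed layout for this sketch).
static std::vector<float> solve_lower_tri(const std::vector<float> & L,
                                          const std::vector<float> & b,
                                          size_t n) {
    std::vector<float> x(n);
    for (size_t i = 0; i < n; ++i) {
        float acc = b[i];
        // Subtract the contribution of the already-solved unknowns
        for (size_t j = 0; j < i; ++j) {
            acc -= L[i*n + j] * x[j];
        }
        // Division by the diagonal entry: a zero here would make the system singular
        x[i] = acc / L[i*n + i];
    }
    return x;
}
```

For `L = [[2, 0], [1, 4]]` and `b = [4, 10]`, the first step gives `x0 = 4 / 2 = 2` and the second gives `x1 = (10 - 1*2) / 4 = 2`.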
+5 -12
@@ -224,7 +224,6 @@ static void clip_log_callback_default(enum ggml_log_level level, const char * te
}
struct clip_logger_state {
ggml_log_level verbosity_thold;
ggml_log_callback log_callback;
void * log_callback_user_data;
};
@@ -258,17 +257,11 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
va_end(args);
}
#define LOG_TMPL(level, ...) \
do { \
if ((level) >= g_logger_state.verbosity_thold) { \
clip_log_internal((level), __VA_ARGS__); \
} \
} while (0)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
#define LOG_INF(...) clip_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) clip_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) clip_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_DBG(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_CNT(...) clip_log_internal(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
//
// cpp wrappers
+1 -3
@@ -24,8 +24,7 @@
#include <array>
#include <functional>
// TODO: allow to pass callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
enum ffn_op_type {
FFN_GELU,
@@ -3507,7 +3506,6 @@ struct clip_model_loader {
};
struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
g_logger_state.verbosity_thold = ctx_params.verbosity;
clip_ctx * ctx_vision = nullptr;
clip_ctx * ctx_audio = nullptr;
-1
@@ -31,7 +31,6 @@ enum clip_flash_attn_type {
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
enum clip_flash_attn_type flash_attn_type;
int image_min_tokens;
int image_max_tokens;
+1 -1
@@ -135,7 +135,6 @@ struct mtmd_cli_context {
mparams.use_gpu = params.mmproj_use_gpu;
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params.flash_attn_type;
mparams.image_min_tokens = params.image_min_tokens;
mparams.image_max_tokens = params.image_max_tokens;
@@ -277,6 +276,7 @@ int main(int argc, char ** argv) {
}
common_init();
mtmd_helper_log_set(common_log_default_callback, nullptr);
if (params.mmproj.path.empty()) {
show_additional_info(argc, argv);
+60 -3
@@ -32,8 +32,65 @@
#define STB_IMAGE_IMPLEMENTATION
#include "stb/stb_image.h"
#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)
//
// internal logging functions
//
struct mtmd_helper_logger {
ggml_log_callback default_callback = [](ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
fputs(text, stderr);
fflush(stderr);
};
ggml_log_callback log_callback = default_callback;
void * log_callback_user_data;
void log_v(enum ggml_log_level level, const char * format, va_list args) {
if (format == NULL) {
return;
}
va_list args_copy;
va_copy(args_copy, args);
char buffer[128];
int len = vsnprintf(buffer, 128, format, args);
if (len < 128) {
log_callback(level, buffer, log_callback_user_data);
} else {
char * buffer2 = (char *) calloc(len + 1, sizeof(char));
vsnprintf(buffer2, len + 1, format, args_copy);
buffer2[len] = 0;
log_callback(level, buffer2, log_callback_user_data);
free(buffer2);
}
va_end(args_copy);
}
void log(enum ggml_log_level level, const char * format, ...) {
va_list args;
va_start(args, format);
log_v(level, format, args);
va_end(args);
}
} g_logger;
#define LOG_INF(...) g_logger.log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) g_logger.log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) g_logger.log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data) {
if (log_callback == nullptr) {
log_callback = g_logger.default_callback;
}
g_logger.log_callback = log_callback;
g_logger.log_callback_user_data = user_data;
mtmd_log_set(log_callback, user_data);
}
//
// helper functions
//
size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
size_t n_tokens = 0;
@@ -325,7 +382,7 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
llama_pos * new_n_past) {
size_t n_chunks = mtmd_input_chunks_size(chunks);
if (n_chunks == 0) {
LOG_ERR("no chunks to eval\n");
LOG_WRN("no chunks to eval\n");
return 0;
}
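The `log_v` method above uses a common two-pass formatting pattern: try a small stack buffer first, and only allocate when `vsnprintf` reports that the message did not fit. A standalone sketch of that pattern follows; `format_message_v` and `format_message` are hypothetical names, and returning a `std::string` (instead of invoking a callback) is a choice made here for illustration.

```cpp
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper: format into a 128-byte stack buffer, falling back to a
// heap buffer of the exact required size when the message is longer.
static std::string format_message_v(const char * format, va_list args) {
    va_list args_copy;
    va_copy(args_copy, args); // the first vsnprintf consumes `args`; keep a copy for the retry

    char small[128];
    const int len = std::vsnprintf(small, sizeof(small), format, args);

    std::string out;
    if (len >= 0 && (size_t) len < sizeof(small)) {
        // The whole message fit in the stack buffer
        out.assign(small, (size_t) len);
    } else if (len >= 0) {
        // vsnprintf returned the required length (excluding the terminator)
        std::vector<char> big((size_t) len + 1);
        std::vsnprintf(big.data(), big.size(), format, args_copy);
        out.assign(big.data(), (size_t) len);
    }
    // len < 0 signals an encoding error; an empty string is returned in that case

    va_end(args_copy);
    return out;
}

static std::string format_message(const char * format, ...) {
    va_list args;
    va_start(args, format);
    std::string out = format_message_v(format, args);
    va_end(args);
    return out;
}
```

The `va_copy` is the important detail: after the first `vsnprintf` call the original `va_list` is in an indeterminate state, so the second pass has to read from the copy.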
+5
@@ -20,6 +20,11 @@ extern "C" {
// BREAKING CHANGES are expected.
//
// Set callback for all future logging events.
// If this is not called, or NULL is supplied, everything is output on stderr.
// Note: this also call mtmd_log_set() internally
MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
// helper function to construct a mtmd_bitmap from a file
// it calls mtmd_helper_bitmap_init_from_buf() internally
// returns nullptr on failure
+5 -2
View File
@@ -105,7 +105,6 @@ mtmd_context_params mtmd_context_params_default() {
/* use_gpu */ true,
/* print_timings */ true,
/* n_threads */ 4,
/* verbosity */ GGML_LOG_LEVEL_INFO,
/* image_marker */ MTMD_DEFAULT_IMAGE_MARKER,
/* media_marker */ mtmd_default_marker(),
/* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
@@ -175,7 +174,6 @@ struct mtmd_context {
clip_context_params ctx_clip_params {
/* use_gpu */ ctx_params.use_gpu,
/* verbosity */ ctx_params.verbosity,
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
/* image_min_tokens */ ctx_params.image_min_tokens,
/* image_max_tokens */ ctx_params.image_max_tokens,
@@ -1096,3 +1094,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
return chunks;
}
void mtmd_log_set(ggml_log_callback log_callback, void * user_data) {
g_logger_state.log_callback = log_callback ? log_callback : clip_log_callback_default;
g_logger_state.log_callback_user_data = user_data;
}
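The setter above follows a "callback with default fallback" pattern: passing a null callback restores the default logger instead of leaving a null pointer behind. The sketch below reproduces that pattern with stand-in types so it compiles on its own; every `demo_*` name is hypothetical, and the real API uses `ggml_log_level` and `ggml_log_callback`.

```cpp
#include <cstdio>
#include <string>

// Stand-in types for illustration only (assumptions, not the ggml definitions)
enum demo_log_level { DEMO_LOG_INFO, DEMO_LOG_WARN, DEMO_LOG_ERROR };
typedef void (*demo_log_callback)(demo_log_level level, const char * text, void * user_data);

// Default sink: write everything to stderr
static void demo_log_default(demo_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr);
    fflush(stderr);
}

struct demo_logger_state {
    demo_log_callback cb        = demo_log_default;
    void *            user_data = nullptr;
};
static demo_logger_state g_demo_logger;

// A null callback falls back to the default, so the logger is always callable
static void demo_log_set(demo_log_callback cb, void * user_data) {
    g_demo_logger.cb        = cb ? cb : demo_log_default;
    g_demo_logger.user_data = user_data;
}

static void demo_log(demo_log_level level, const char * text) {
    g_demo_logger.cb(level, text, g_demo_logger.user_data);
}

// Example user callback: append each message to a std::string passed via user_data
static void collect_to_string(demo_log_level level, const char * text, void * user_data) {
    (void) level;
    static_cast<std::string *>(user_data)->append(text);
}
```

A caller would register `collect_to_string` together with a pointer to its own `std::string`, and later call `demo_log_set(nullptr, nullptr)` to return to stderr output.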
+4 -1
View File
@@ -79,7 +79,6 @@ struct mtmd_context_params {
bool use_gpu;
bool print_timings;
int n_threads;
enum ggml_log_level verbosity;
const char * image_marker; // deprecated, use media_marker instead
const char * media_marker;
enum llama_flash_attn_type flash_attn_type;
@@ -215,6 +214,10 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
// Set callback for all future logging events.
// If this is not called, or NULL is supplied, everything is output on stderr.
MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data);
/////////////////////////////////////////
// test function, to be used in test-mtmd-c-api.c
Binary file not shown.
+55 -50
@@ -2454,11 +2454,12 @@ struct server_context {
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mparams.flash_attn_type = params_base.flash_attn_type;
mparams.image_min_tokens = params_base.image_min_tokens;
mparams.image_max_tokens = params_base.image_max_tokens;
@@ -3591,13 +3592,13 @@ struct server_context {
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
if (!slot.is_processing()) {
continue;
}
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
if (!slot_batched) {
slot_batched = &slot;
} else if (!slot_batched->can_batch_with(slot)) {
continue;
}
if (slot_batched && !slot_batched->can_batch_with(slot)) {
continue;
}
// this slot still has a prompt to be processed
@@ -4028,6 +4029,10 @@ struct server_context {
}
}
if (!slot_batched) {
slot_batched = &slot;
}
if (batch.n_tokens >= n_batch) {
break;
}
@@ -4431,7 +4436,7 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
SRV_DBG("response: %s\n", res.body.c_str());
}
static void res_error(httplib::Response & res, const json & error_data) {
static void res_err(httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
@@ -4524,7 +4529,7 @@ int main(int argc, char ** argv) {
try {
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
res_error(res, formatted_error);
res_err(res, formatted_error);
} catch (const std::exception & e) {
LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
}
@@ -4532,9 +4537,9 @@ int main(int argc, char ** argv) {
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) {
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
}
// for other error codes, we skip processing here because it's already done by res_error()
// for other error codes, we skip processing here because it's already done by res_err()
});
// set timeouts and change hostname and port
@@ -4591,7 +4596,7 @@ int main(int argc, char ** argv) {
}
// API key is invalid or not provided
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
res_err(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
LOG_WRN("Unauthorized: Invalid API Key\n");
@@ -4609,7 +4614,7 @@ int main(int argc, char ** argv) {
// allow the models endpoint to be accessed during loading
return true;
} else {
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
}
return false;
}
@@ -4648,7 +4653,7 @@ int main(int argc, char ** argv) {
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
if (!params.endpoint_slots) {
res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -4666,7 +4671,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
res_err(res, result->to_json());
return;
}
@@ -4677,7 +4682,7 @@ int main(int argc, char ** argv) {
// optionally return "fail_on_no_slot" error
if (req.has_param("fail_on_no_slot")) {
if (res_task->n_idle_slots == 0) {
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
res_err(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
return;
}
}
@@ -4687,7 +4692,7 @@ int main(int argc, char ** argv) {
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!params.endpoint_metrics) {
res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -4705,7 +4710,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
res_err(res, result->to_json());
return;
}
@@ -4790,7 +4795,7 @@ int main(int argc, char ** argv) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;
@@ -4811,7 +4816,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
res_err(res, result->to_json());
return;
}
@@ -4822,7 +4827,7 @@ int main(int argc, char ** argv) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;
@@ -4843,7 +4848,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
res_err(res, result->to_json());
return;
}
@@ -4866,7 +4871,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
res_err(res, result->to_json());
return;
}
@@ -4876,7 +4881,7 @@ int main(int argc, char ** argv) {
const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
if (params.slot_save_path.empty()) {
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -4886,7 +4891,7 @@ int main(int argc, char ** argv) {
try {
id_slot = std::stoi(id_slot_str);
} catch (const std::exception &) {
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -4899,7 +4904,7 @@ int main(int argc, char ** argv) {
} else if (action == "erase") {
handle_slots_erase(req, res, id_slot);
} else {
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
}
};
@@ -4947,7 +4952,7 @@ int main(int argc, char ** argv) {
const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.endpoint_props) {
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -5044,7 +5049,7 @@ int main(int argc, char ** argv) {
rd->post_tasks(std::move(tasks));
} catch (const std::exception & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -5056,7 +5061,7 @@ int main(int argc, char ** argv) {
if (all_results.is_terminated) {
return; // connection is closed
} else if (all_results.error) {
res_error(res, all_results.error->to_json());
res_err(res, all_results.error->to_json());
return;
} else {
json arr = json::array();
@@ -5076,7 +5081,7 @@ int main(int argc, char ** argv) {
if (first_result == nullptr) {
return; // connection is closed
} else if (first_result->is_error()) {
res_error(res, first_result->to_json());
res_err(res, first_result->to_json());
return;
} else {
GGML_ASSERT(
@@ -5183,7 +5188,7 @@ int main(int argc, char ** argv) {
err += "middle token is missing. ";
}
if (!err.empty()) {
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -5192,20 +5197,20 @@ int main(int argc, char ** argv) {
// validate input
if (data.contains("prompt") && !data.at("prompt").is_string()) {
// prompt is optional
res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
}
if (!data.contains("input_prefix")) {
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
}
if (!data.contains("input_suffix")) {
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
}
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
// input_extra is optional
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -5213,12 +5218,12 @@ int main(int argc, char ** argv) {
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
return;
}
// filename is optional
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
@@ -5380,12 +5385,12 @@ int main(int argc, char ** argv) {
const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
if (!ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -5399,7 +5404,7 @@ int main(int argc, char ** argv) {
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
prompt = body.at("content");
} else {
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -5409,7 +5414,7 @@ int main(int argc, char ** argv) {
if (format == "base64") {
use_base64 = true;
} else if (format != "float") {
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
@@ -5418,7 +5423,7 @@ int main(int argc, char ** argv) {
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) {
res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
@@ -5459,7 +5464,7 @@ int main(int argc, char ** argv) {
if (all_results.is_terminated) {
return; // connection is closed
} else if (all_results.error) {
res_error(res, all_results.error->to_json());
res_err(res, all_results.error->to_json());
return;
} else {
for (auto & res : all_results.results) {
@@ -5485,7 +5490,7 @@ int main(int argc, char ** argv) {
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
res_err(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -5500,18 +5505,18 @@ int main(int argc, char ** argv) {
if (body.count("query") == 1) {
query = body.at("query");
if (!query.is_string()) {
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
return;
}
} else {
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::vector<std::string> documents = json_value(body, "documents",
json_value(body, "texts", std::vector<std::string>()));
if (documents.empty()) {
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -5541,7 +5546,7 @@ int main(int argc, char ** argv) {
if (all_results.is_terminated) {
return; // connection is closed
} else if (all_results.error) {
res_error(res, all_results.error->to_json());
res_err(res, all_results.error->to_json());
return;
} else {
for (auto & res : all_results.results) {
@@ -5594,7 +5599,7 @@ int main(int argc, char ** argv) {
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
if (!body.is_array()) {
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
res_err(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
return;
}
@@ -5612,7 +5617,7 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
res_err(res, result->to_json());
return;
}
@@ -1,6 +1,5 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { RemoveButton } from '$lib/components/app';
import { formatFileSize, getFileTypeLabel, getPreviewText } from '$lib/utils/file-preview';
import { FileTypeCategory, MimeTypeText } from '$lib/enums/files';
@@ -66,17 +65,15 @@
</button>
{:else}
<!-- Non-readonly mode (ChatForm) -->
<div class="relative rounded-lg border border-border bg-muted p-3 {className} w-64">
<Button
type="button"
variant="ghost"
size="sm"
class="absolute top-2 right-2 h-6 w-6 bg-white/20 p-0 hover:bg-white/30"
onclick={() => onRemove?.(id)}
aria-label="Remove file"
>
<X class="h-3 w-3" />
</Button>
<button
class="group relative rounded-lg border border-border bg-muted p-3 {className} {textContent
? 'max-h-24 max-w-72'
: 'max-w-36'} cursor-pointer text-left"
onclick={onClick}
>
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<RemoveButton {id} {onRemove} />
</div>
<div class="pr-8">
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
@@ -85,7 +82,7 @@
<div class="relative">
<div
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
style="max-height: 3.6em; line-height: 1.2em;"
style="max-height: 3rem; line-height: 1.2em;"
>
{getPreviewText(textContent)}
</div>
@@ -98,11 +95,11 @@
</div>
{/if}
</div>
</div>
</button>
{/if}
{:else}
<button
class="flex items-center gap-2 gap-3 rounded-lg border border-border bg-muted p-3 {className}"
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
onclick={onClick}
>
<div
@@ -112,7 +109,9 @@
</div>
<div class="flex flex-col gap-1">
<span class="max-w-36 truncate text-sm font-medium text-foreground md:max-w-72">
<span
class="max-w-24 truncate text-sm font-medium text-foreground group-hover:pr-6 md:max-w-32"
>
{name}
</span>
@@ -122,18 +121,9 @@
</div>
{#if !readonly}
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 w-6 p-0"
onclick={(e) => {
e.stopPropagation();
onRemove?.(id);
}}
>
<X class="h-3 w-3" />
</Button>
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
<RemoveButton {id} {onRemove} />
</div>
{/if}
</button>
{/if}
@@ -1,6 +1,5 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { RemoveButton } from '$lib/components/app';
interface Props {
id: string;
@@ -26,12 +25,12 @@
class: className = '',
// Default to small size for form previews
width = 'w-auto',
height = 'h-24',
height = 'h-16',
imageClass = ''
}: Props = $props();
</script>
<div class="relative overflow-hidden rounded-lg border border-border bg-muted {className}">
<div class="group relative overflow-hidden rounded-lg border border-border bg-muted {className}">
{#if onClick}
<button
type="button"
@@ -55,17 +54,9 @@
{#if !readonly}
<div
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity hover:opacity-100"
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
>
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 w-6 bg-white/20 p-0 text-white hover:bg-white/30"
onclick={() => onRemove?.(id)}
>
<X class="h-3 w-3" />
</Button>
<RemoveButton {id} {onRemove} class="text-white" />
</div>
{/if}
</div>
@@ -153,7 +153,7 @@
<Dialog.Root bind:open>
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden !p-10 sm:w-auto sm:max-w-6xl">
<Dialog.Header class="flex-shrink-0">
<div class="flex items-center justify-between">
<div class="flex items-center justify-between gap-6">
<div class="flex items-center gap-3">
{#if IconComponent}
<IconComponent class="h-5 w-5 text-muted-foreground" />
@@ -1,11 +1,16 @@
<script lang="ts">
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
import { FileTypeCategory } from '$lib/enums/files';
import { getFileTypeCategory } from '$lib/utils/file-type';
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
import ChatAttachmentsViewAllDialog from './ChatAttachmentsViewAllDialog.svelte';
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
interface Props {
class?: string;
style?: string;
// For ChatMessage - stored attachments
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
@@ -16,10 +21,13 @@
imageClass?: string;
imageHeight?: string;
imageWidth?: string;
// Limit display to single row with "+ X more" button
limitToSingleRow?: boolean;
}
let {
class: className = '',
style = '',
attachments = [],
readonly = false,
onFileRemove,
@@ -27,36 +35,23 @@
// Default to small size for form previews
imageClass = '',
imageHeight = 'h-24',
imageWidth = 'w-auto'
imageWidth = 'w-auto',
limitToSingleRow = false
}: Props = $props();
let displayItems = $derived(getDisplayItems());
// Preview dialog state
let canScrollLeft = $state(false);
let canScrollRight = $state(false);
let isScrollable = $state(false);
let previewDialogOpen = $state(false);
let previewItem = $state<{
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
preview?: string;
name?: string;
type?: string;
size?: number;
textContent?: string;
} | null>(null);
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
let scrollContainer: HTMLDivElement | undefined = $state();
let showViewAll = $derived(limitToSingleRow && displayItems.length > 0 && isScrollable);
let viewAllDialogOpen = $state(false);
function getDisplayItems() {
const items: Array<{
id: string;
name: string;
size?: number;
preview?: string;
type: string;
isImage: boolean;
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
attachmentIndex?: number;
textContent?: string;
}> = [];
function getDisplayItems(): ChatAttachmentDisplayItem[] {
const items: ChatAttachmentDisplayItem[] = [];
// Add uploaded files (ChatForm)
for (const file of uploadedFiles) {
@@ -127,14 +122,12 @@
}
}
return items;
return items.reverse();
}
function openPreview(item: (typeof displayItems)[0], event?: Event) {
if (event) {
event.preventDefault();
event.stopPropagation();
}
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
previewItem = {
uploadedFile: item.uploadedFile,
@@ -147,38 +140,118 @@
};
previewDialogOpen = true;
}
function scrollLeft(event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
if (!scrollContainer) return;
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * -0.67, behavior: 'smooth' });
}
function scrollRight(event?: MouseEvent) {
event?.stopPropagation();
event?.preventDefault();
if (!scrollContainer) return;
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * 0.67, behavior: 'smooth' });
}
function updateScrollButtons() {
if (!scrollContainer) return;
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
canScrollLeft = scrollLeft > 0;
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1;
isScrollable = scrollWidth > clientWidth;
}
$effect(() => {
if (scrollContainer && displayItems.length) {
scrollContainer.scrollLeft = 0;
setTimeout(() => {
updateScrollButtons();
}, 0);
}
});
</script>
{#if displayItems.length > 0}
<div class="flex flex-wrap items-start {readonly ? 'justify-end' : ''} gap-3 {className}">
{#each displayItems as item (item.id)}
{#if item.isImage && item.preview}
<ChatAttachmentImagePreview
class="cursor-pointer"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentFilePreview
class="cursor-pointer"
id={item.id}
name={item.name}
type={item.type}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
<div class={className} {style}>
<div class="relative">
<button
class="absolute top-1/2 left-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollLeft
? 'opacity-100'
: 'pointer-events-none opacity-0'}"
onclick={scrollLeft}
aria-label="Scroll left"
>
<ChevronLeft class="h-4 w-4" />
</button>
<div
class="scrollbar-hide flex items-start gap-3 overflow-x-auto"
bind:this={scrollContainer}
onscroll={updateScrollButtons}
>
{#each displayItems as item (item.id)}
{#if item.isImage && item.preview}
<ChatAttachmentImagePreview
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{:else}
<ChatAttachmentFilePreview
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
id={item.id}
name={item.name}
type={item.type}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</div>
<button
class="absolute top-1/2 right-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollRight
? 'opacity-100'
: 'pointer-events-none opacity-0'}"
onclick={scrollRight}
aria-label="Scroll right"
>
<ChevronRight class="h-4 w-4" />
</button>
</div>
{#if showViewAll}
<div class="mt-2 -mr-2 flex justify-end px-4">
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 text-xs text-muted-foreground hover:text-foreground"
onclick={() => (viewAllDialogOpen = true)}
>
View all
</Button>
</div>
{/if}
</div>
{/if}
@@ -194,3 +267,13 @@
textContent={previewItem.textContent}
/>
{/if}
<ChatAttachmentsViewAllDialog
bind:open={viewAllDialogOpen}
{uploadedFiles}
{attachments}
{readonly}
{onFileRemove}
imageHeight="h-64"
{imageClass}
/>
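Review note: the scroll-button state set in `updateScrollButtons` above depends only on three numbers read from the container. The sketch below pulls that computation into a pure function so it can be checked without a DOM. The names `computeScrollState` and `ScrollState` are hypothetical and not part of this change; the comparisons mirror the ones in the hunk, including the 1px slack on the right edge, which presumably absorbs fractional scroll offsets.

```typescript
// Hypothetical helper, not part of the commit: mirrors the comparisons
// performed inside updateScrollButtons() in ChatAttachmentsList.
interface ScrollState {
	canScrollLeft: boolean;
	canScrollRight: boolean;
	isScrollable: boolean;
}

function computeScrollState(
	scrollLeft: number,
	scrollWidth: number,
	clientWidth: number
): ScrollState {
	return {
		// Any offset from the start means there is content to the left.
		canScrollLeft: scrollLeft > 0,
		// 1px of slack so a fractional offset at the far end still counts as "at the end".
		canScrollRight: scrollLeft < scrollWidth - clientWidth - 1,
		// Content wider than the viewport is what enables the "View all" button.
		isScrollable: scrollWidth > clientWidth
	};
}

// Content fits exactly: no arrows, no "View all".
console.log(computeScrollState(0, 300, 300));
// Content overflows and the view is at the start: only the right arrow shows.
console.log(computeScrollState(0, 900, 300));
// Scrolled to the far end: only the left arrow shows.
console.log(computeScrollState(600, 900, 300));
```

Inside the component the same function could be fed `scrollContainer.scrollLeft`, `scrollContainer.scrollWidth` and `scrollContainer.clientWidth`, leaving the `$state` assignments as the only DOM-dependent part.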
@@ -0,0 +1,203 @@
<script lang="ts">
import * as Dialog from '$lib/components/ui/dialog';
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
import { FileTypeCategory } from '$lib/enums/files';
import { getFileTypeCategory } from '$lib/utils/file-type';
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
interface Props {
open?: boolean;
uploadedFiles?: ChatUploadedFile[];
attachments?: DatabaseMessageExtra[];
readonly?: boolean;
onFileRemove?: (fileId: string) => void;
imageHeight?: string;
imageWidth?: string;
imageClass?: string;
}
let {
open = $bindable(false),
uploadedFiles = [],
attachments = [],
readonly = false,
onFileRemove,
imageHeight = 'h-24',
imageWidth = 'w-auto',
imageClass = ''
}: Props = $props();
let previewDialogOpen = $state(false);
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
let displayItems = $derived(getDisplayItems());
let imageItems = $derived(displayItems.filter((item) => item.isImage));
let fileItems = $derived(displayItems.filter((item) => !item.isImage));
function getDisplayItems(): ChatAttachmentDisplayItem[] {
const items: ChatAttachmentDisplayItem[] = [];
for (const file of uploadedFiles) {
items.push({
id: file.id,
name: file.name,
size: file.size,
preview: file.preview,
type: file.type,
isImage: getFileTypeCategory(file.type) === FileTypeCategory.IMAGE,
uploadedFile: file,
textContent: file.textContent
});
}
for (const [index, attachment] of attachments.entries()) {
if (attachment.type === 'imageFile') {
items.push({
id: `attachment-${index}`,
name: attachment.name,
preview: attachment.base64Url,
type: 'image',
isImage: true,
attachment,
attachmentIndex: index
});
} else if (attachment.type === 'textFile') {
items.push({
id: `attachment-${index}`,
name: attachment.name,
type: 'text',
isImage: false,
attachment,
attachmentIndex: index,
textContent: attachment.content
});
} else if (attachment.type === 'context') {
// Legacy format from old webui - treat as text file
items.push({
id: `attachment-${index}`,
name: attachment.name,
type: 'text',
isImage: false,
attachment,
attachmentIndex: index,
textContent: attachment.content
});
} else if (attachment.type === 'audioFile') {
items.push({
id: `attachment-${index}`,
name: attachment.name,
type: attachment.mimeType || 'audio',
isImage: false,
attachment,
attachmentIndex: index
});
} else if (attachment.type === 'pdfFile') {
items.push({
id: `attachment-${index}`,
name: attachment.name,
type: 'application/pdf',
isImage: false,
attachment,
attachmentIndex: index,
textContent: attachment.content
});
}
}
return items.reverse();
}
function openPreview(item: (typeof displayItems)[0], event?: Event) {
if (event) {
event.preventDefault();
event.stopPropagation();
}
previewItem = {
uploadedFile: item.uploadedFile,
attachment: item.attachment,
preview: item.preview,
name: item.name,
type: item.type,
size: item.size,
textContent: item.textContent
};
previewDialogOpen = true;
}
</script>
<Dialog.Root bind:open>
<Dialog.Portal>
<Dialog.Overlay />
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
<Dialog.Header>
<Dialog.Title>All Attachments ({displayItems.length})</Dialog.Title>
<Dialog.Description class="text-sm text-muted-foreground">
View and manage all attached files
</Dialog.Description>
</Dialog.Header>
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
{#if fileItems.length > 0}
<div>
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
<div class="flex flex-wrap items-start gap-3">
{#each fileItems as item (item.id)}
<ChatAttachmentFilePreview
class="cursor-pointer"
id={item.id}
name={item.name}
type={item.type}
size={item.size}
{readonly}
onRemove={onFileRemove}
textContent={item.textContent}
onClick={(event) => openPreview(item, event)}
/>
{/each}
</div>
</div>
{/if}
{#if imageItems.length > 0}
<div>
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
<div class="flex flex-wrap items-start gap-3">
{#each imageItems as item (item.id)}
{#if item.preview}
<ChatAttachmentImagePreview
class="cursor-pointer"
id={item.id}
name={item.name}
preview={item.preview}
{readonly}
onRemove={onFileRemove}
height={imageHeight}
width={imageWidth}
{imageClass}
onClick={(event) => openPreview(item, event)}
/>
{/if}
{/each}
</div>
</div>
{/if}
</div>
</Dialog.Content>
</Dialog.Portal>
</Dialog.Root>
{#if previewItem}
<ChatAttachmentPreviewDialog
bind:open={previewDialogOpen}
uploadedFile={previewItem.uploadedFile}
attachment={previewItem.attachment}
preview={previewItem.preview}
name={previewItem.name}
type={previewItem.type}
size={previewItem.size}
textContent={previewItem.textContent}
/>
{/if}
@@ -232,7 +232,13 @@
onsubmit={handleSubmit}
class="{INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {className}"
>
<ChatAttachmentsList bind:uploadedFiles {onFileRemove} class="mb-3 px-5 pt-5" />
<ChatAttachmentsList
bind:uploadedFiles
{onFileRemove}
limitToSingleRow
class="py-5"
style="scroll-padding: 1rem;"
/>
<div
class="flex-column relative min-h-[48px] items-center rounded-3xl px-5 py-3 shadow-sm transition-all focus-within:shadow-md"
@@ -333,7 +333,7 @@
ondrop={handleDrop}
role="main"
>
<div class="w-full max-w-2xl px-4">
<div class="w-full max-w-[48rem] px-4">
<div class="mb-8 text-center" in:fade={{ duration: 300 }}>
<h1 class="mb-2 text-3xl font-semibold tracking-tight">llama.cpp</h1>
@@ -368,7 +368,7 @@
<AlertDialog.Portal>
<AlertDialog.Overlay />
<AlertDialog.Content class="max-w-md">
<AlertDialog.Content class="flex max-w-md flex-col">
<AlertDialog.Header>
<AlertDialog.Title>File Upload Error</AlertDialog.Title>
@@ -377,7 +377,7 @@
</AlertDialog.Description>
</AlertDialog.Header>
<div class="space-y-4">
<div class="!max-h-[50vh] min-h-0 flex-1 space-y-4 overflow-y-auto">
{#if fileErrorData.generallyUnsupported.length > 0}
<div class="space-y-2">
<h4 class="text-sm font-medium text-destructive">Unsupported File Types</h4>
@@ -398,8 +398,6 @@
{#if fileErrorData.modalityUnsupported.length > 0}
<div class="space-y-2">
<h4 class="text-sm font-medium text-destructive">Model Compatibility Issues</h4>
<div class="space-y-1">
{#each fileErrorData.modalityUnsupported as file (file.name)}
<div class="rounded-md bg-destructive/10 px-3 py-2">
@@ -415,14 +413,14 @@
</div>
</div>
{/if}
</div>
<div class="rounded-md bg-muted/50 p-3">
<h4 class="mb-2 text-sm font-medium">This model supports:</h4>
<div class="rounded-md bg-muted/50 p-3">
<h4 class="mb-2 text-sm font-medium">This model supports:</h4>
<p class="text-sm text-muted-foreground">
{fileErrorData.supportedTypes.join(', ')}
</p>
</div>
<p class="text-sm text-muted-foreground">
{fileErrorData.supportedTypes.join(', ')}
</p>
</div>
<AlertDialog.Footer>
@@ -2,6 +2,7 @@ export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttac
export { default as ChatAttachmentFilePreview } from './chat/ChatAttachments/ChatAttachmentFilePreview.svelte';
export { default as ChatAttachmentImagePreview } from './chat/ChatAttachments/ChatAttachmentImagePreview.svelte';
export { default as ChatAttachmentPreviewDialog } from './chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte';
export { default as ChatAttachmentsViewAllDialog } from './chat/ChatAttachments/ChatAttachmentsViewAllDialog.svelte';
export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte';
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';
@@ -42,6 +43,8 @@ export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.sve
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
export { default as RemoveButton } from './misc/RemoveButton.svelte';
export { default as ServerStatus } from './server/ServerStatus.svelte';
export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte';
export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte';
@@ -0,0 +1,26 @@
<script lang="ts">
import { X } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
interface Props {
id: string;
onRemove?: (id: string) => void;
class?: string;
}
let { id, onRemove, class: className = '' }: Props = $props();
</script>
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}"
onclick={(e) => {
e.stopPropagation();
onRemove?.(id);
}}
aria-label="Remove file"
>
<X class="h-3 w-3" />
</Button>
@@ -11,6 +11,29 @@ export interface ChatUploadedFile {
textContent?: string;
}
export interface ChatAttachmentDisplayItem {
id: string;
name: string;
size?: number;
preview?: string;
type: string;
isImage: boolean;
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
attachmentIndex?: number;
textContent?: string;
}
export interface ChatAttachmentPreviewItem {
uploadedFile?: ChatUploadedFile;
attachment?: DatabaseMessageExtra;
preview?: string;
name?: string;
type?: string;
size?: number;
textContent?: string;
}
export interface ChatMessageSiblingInfo {
message: DatabaseMessage;
siblingIds: string[];