mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 17:47:40 +02:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1568d13c2c | |||
| 439342ea0b | |||
| 234ae7d7bd | |||
| 38eaf32af1 | |||
| 9b17d74ab7 | |||
| e1fcf8b09b | |||
| 6cd0cf72ce | |||
| d396b43748 | |||
| 45c6ef7307 | |||
| 2606b0adab | |||
| 307772fcda | |||
| f1bad23f88 | |||
| becc4816dd | |||
| c4abcb2457 | |||
| 389ac78b26 | |||
| a19bd6f7ce |
@@ -9,7 +9,7 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
|
||||
- **Size**: ~200k+ lines of code across 1000+ files
|
||||
- **Architecture**: Modular design with main library (`libllama`) and 40+ executable tools/examples
|
||||
- **Core dependency**: ggml tensor library (vendored in `ggml/` directory)
|
||||
- **Backends supported**: CPU (AVX/NEON optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **License**: MIT
|
||||
|
||||
## Build Instructions
|
||||
|
||||
@@ -61,6 +61,7 @@ range of hardware - locally and in the cloud.
|
||||
- Plain C/C++ implementation without any dependencies
|
||||
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
|
||||
- AVX, AVX2, AVX512 and AMX support for x86 architectures
|
||||
- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
|
||||
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
|
||||
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
|
||||
- Vulkan and SYCL backend support
|
||||
|
||||
+1
-5
@@ -355,11 +355,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}, NULL);
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
|
||||
@@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
||||
log->set_timestamps(timestamps);
|
||||
}
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
// the common_log uses an internal worker thread to print/write log messages
|
||||
// when the worker thread is paused, incoming log messages are discarded
|
||||
struct common_log;
|
||||
|
||||
@@ -1124,6 +1124,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
|
||||
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
|
||||
res = "afmoe"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
|
||||
res = "bailingmoe2"
|
||||
@@ -2533,6 +2536,81 @@ class ArceeModel(LlamaModel):
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@ModelBase.register("AfmoeForCausalLM")
|
||||
class AfmoeModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.AFMOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# MoE parameters
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
|
||||
self.gguf_writer.add_expert_shared_count(n_shared_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
|
||||
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
|
||||
|
||||
# Expert Gating Function
|
||||
score_func = self.hparams.get("score_func")
|
||||
if score_func == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif score_func == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
elif score_func is not None:
|
||||
raise ValueError(f"Unsupported score_function value: {score_func}")
|
||||
|
||||
# Route normalization and scaling
|
||||
if (route_norm := self.hparams.get("route_norm")) is not None:
|
||||
self.gguf_writer.add_expert_weights_norm(route_norm)
|
||||
if (route_scale := self.hparams.get("route_scale")) is not None:
|
||||
self.gguf_writer.add_expert_weights_scale(route_scale)
|
||||
|
||||
# Sliding window attention
|
||||
if (sliding_window := self.hparams.get("sliding_window")) is not None:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Handle expert weights - they're already merged in the HF format
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["gate_proj", "up_proj", "down_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename_to_retrieve])
|
||||
del self._experts[bid][ename_to_retrieve]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"LlavaForConditionalGeneration", # pixtral
|
||||
"Mistral3ForConditionalGeneration", # mistral small 3.1
|
||||
|
||||
@@ -139,6 +139,7 @@ models = [
|
||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||
|
||||
+31
-25
@@ -14,33 +14,36 @@ Legend:
|
||||
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
@@ -54,25 +57,25 @@ Legend:
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
@@ -80,15 +83,15 @@ Legend:
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
@@ -96,21 +99,24 @@ Legend:
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
+16067
-5133
File diff suppressed because it is too large
Load Diff
+16224
-6894
File diff suppressed because it is too large
Load Diff
+14534
-4358
File diff suppressed because it is too large
Load Diff
@@ -475,6 +475,7 @@ extern "C" {
|
||||
GGML_OP_COS,
|
||||
GGML_OP_SUM,
|
||||
GGML_OP_SUM_ROWS,
|
||||
GGML_OP_CUMSUM,
|
||||
GGML_OP_MEAN,
|
||||
GGML_OP_ARGMAX,
|
||||
GGML_OP_COUNT_EQUAL,
|
||||
@@ -530,6 +531,8 @@ extern "C" {
|
||||
GGML_OP_TIMESTEP_EMBEDDING,
|
||||
GGML_OP_ARGSORT,
|
||||
GGML_OP_LEAKY_RELU,
|
||||
GGML_OP_TRI,
|
||||
GGML_OP_FILL,
|
||||
|
||||
GGML_OP_FLASH_ATTN_EXT,
|
||||
GGML_OP_FLASH_ATTN_BACK,
|
||||
@@ -542,6 +545,7 @@ extern "C" {
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
GGML_OP_RWKV_WKV7,
|
||||
GGML_OP_SOLVE_TRI,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@@ -576,6 +580,8 @@ extern "C" {
|
||||
GGML_UNARY_OP_HARDSWISH,
|
||||
GGML_UNARY_OP_HARDSIGMOID,
|
||||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_EXPM1,
|
||||
GGML_UNARY_OP_SOFTPLUS,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_XIELU,
|
||||
GGML_UNARY_OP_FLOOR,
|
||||
@@ -620,6 +626,13 @@ extern "C" {
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
};
|
||||
|
||||
enum ggml_tri_type {
|
||||
GGML_TRI_TYPE_UPPER_DIAG = 0,
|
||||
GGML_TRI_TYPE_UPPER = 1,
|
||||
GGML_TRI_TYPE_LOWER_DIAG = 2,
|
||||
GGML_TRI_TYPE_LOWER = 3
|
||||
};
|
||||
|
||||
struct ggml_init_params {
|
||||
// memory pool
|
||||
size_t mem_size; // bytes
|
||||
@@ -957,6 +970,22 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_sin(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
@@ -983,6 +1012,10 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cumsum(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// mean along rows
|
||||
GGML_API struct ggml_tensor * ggml_mean(
|
||||
struct ggml_context * ctx,
|
||||
@@ -2187,6 +2220,23 @@ extern "C" {
|
||||
int shift2,
|
||||
int shift3);
|
||||
|
||||
// Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
|
||||
// zeroes everywhere outside the masked area
|
||||
GGML_API struct ggml_tensor * ggml_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_tri_type type);
|
||||
|
||||
// Fill tensor a with constant c
|
||||
GGML_API struct ggml_tensor * ggml_fill(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_fill_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||||
// timesteps: [N,]
|
||||
@@ -2356,6 +2406,27 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
|
||||
* without zeroes on the diagonal (i.e. invertible).
|
||||
* B can have any number of columns, but must have the same number of rows as A
|
||||
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
|
||||
* Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
|
||||
* where n > 100 sparingly, pre-chunk if necessary.
|
||||
*
|
||||
* If left = false, solves xA=B instead
|
||||
* If lower = false, assumes upper triangular instead
|
||||
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
|
||||
*
|
||||
* TODO: currently only lower, right, non-unitriangular variant is implemented
|
||||
*/
|
||||
GGML_API struct ggml_tensor * ggml_solve_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool left,
|
||||
bool lower,
|
||||
bool uni);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
|
||||
|
||||
@@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_sum_rows(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
{
|
||||
ggml_compute_forward_cumsum(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
ggml_compute_forward_mean(params, tensor);
|
||||
@@ -1927,6 +1931,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_leaky_relu(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
ggml_compute_forward_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
ggml_compute_forward_fill(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_compute_forward_flash_attn_ext(params, tensor);
|
||||
@@ -1982,6 +1994,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
ggml_compute_forward_solve_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_CUSTOM1:
|
||||
{
|
||||
ggml_compute_forward_map_custom1(params, tensor);
|
||||
@@ -2140,6 +2156,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2157,6 +2176,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2179,6 +2199,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
|
||||
+208
-2
@@ -9,6 +9,7 @@
|
||||
|
||||
#include <cfloat>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
// ggml_compute_forward_dup
|
||||
@@ -1395,6 +1396,56 @@ void ggml_compute_forward_sum(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_cumsum
|
||||
|
||||
static void ggml_compute_forward_cumsum_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ne0 == ne00);
|
||||
GGML_ASSERT(ne1 == ne01);
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
GGML_ASSERT(ne3 == ne03);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_cumsum(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_cumsum_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_sum_rows
|
||||
|
||||
static void ggml_compute_forward_sum_rows_f32(
|
||||
@@ -2141,6 +2192,83 @@ static void ggml_compute_forward_gelu(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_fill
|
||||
|
||||
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const float c = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, dst);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne2*ne1);
|
||||
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
||||
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
ggml_vec_set_f32(ne0, dst_ptr, c);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
ggml_compute_forward_fill_f32(params, dst);
|
||||
}
|
||||
|
||||
// ggml_compute_tri
|
||||
|
||||
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
bool (*bipred)(int, int);
|
||||
|
||||
switch (ttype) {
|
||||
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
|
||||
default: GGML_ABORT("invalid tri type");
|
||||
}
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
for (int i0 = 0; i0 < ne0; ++i0) {
|
||||
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_tri_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_gelu_erf
|
||||
|
||||
static void ggml_compute_forward_gelu_erf_f32(
|
||||
@@ -8536,7 +8664,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const float dA = expf(dt_soft_plus * A[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
@@ -8633,7 +8761,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
// dim
|
||||
@@ -8916,6 +9044,14 @@ void ggml_compute_forward_unary(
|
||||
{
|
||||
ggml_compute_forward_xielu(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
{
|
||||
ggml_compute_forward_expm1(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
{
|
||||
ggml_compute_forward_softplus(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -9512,6 +9648,76 @@ void ggml_compute_forward_gla(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ne00 == ne01); // A must be square
|
||||
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
|
||||
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
|
||||
|
||||
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
|
||||
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t k = ne10; // number of RHS columns
|
||||
const int64_t n = ne11; // A is n×n
|
||||
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
|
||||
|
||||
// chunks per thread
|
||||
const int64_t dr = (nr + nth - 1)/nth;
|
||||
|
||||
// chunk range for this thread
|
||||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float * A = (const float *) src0->data; // [n, n, B1, B2]
|
||||
const float * B = (const float *) src1->data; // [n, k, B1, B2]
|
||||
float * X = ( float *) dst->data; // [n, k, B1, B2]
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*k);
|
||||
const int64_t i02 = (ir - i03*ne02*k)/k;
|
||||
const int64_t i01 = (ir - i03*ne02*k - i02*k);
|
||||
|
||||
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
|
||||
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
|
||||
|
||||
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
|
||||
|
||||
for (int64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (int64_t t = 0; t < i00; ++t) {
|
||||
sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
|
||||
}
|
||||
|
||||
const float diag = A_batch[i00 * n + i00];
|
||||
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
|
||||
X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_compute_forward_solve_tri_f32(params, dst);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv7
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
|
||||
@@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
@@ -81,6 +82,8 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
|
||||
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_back(
|
||||
const struct ggml_compute_params * params,
|
||||
@@ -96,6 +99,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
|
||||
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -1600,29 +1600,52 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params,
|
||||
ggml_tensor * op,
|
||||
int64_t src0_start,
|
||||
int64_t src0_end,
|
||||
int64_t src1_start,
|
||||
int64_t src1_end) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
ggml_tensor * dst = op;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const void * src1_wdata = params->wdata;
|
||||
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
|
||||
|
||||
GGML_ASSERT(ne03 == 1 && ne13 == 1);
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
|
||||
const int64_t i12 = src1_start / ne1;
|
||||
const int64_t i11 = src1_start - i12 * ne1;
|
||||
|
||||
// Determine batch index
|
||||
const int64_t i02 = i12 / r2;
|
||||
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
|
||||
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
|
||||
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
|
||||
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
|
||||
|
||||
const int64_t nrows = src1_end - src1_start;
|
||||
const int64_t ncols = src0_end - src0_start;
|
||||
|
||||
GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (ne11 > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
if (nrows > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
|
||||
src0_ptr + src0_start * nb01, src1_ptr,
|
||||
nrows - (nrows % 4), ncols);
|
||||
}
|
||||
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
|
||||
ne01, src0_ptr + src0_start * nb01,
|
||||
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1647,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// TODO: General batched mul mat for 4D tensors
|
||||
// Currently only supports 3D tensors
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne3 == 1);
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
|
||||
@@ -1654,47 +1683,64 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
|
||||
char * wdata = static_cast<char *>(params->wdata);
|
||||
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
|
||||
assert(params->wsize >= nbw1 * ne11);
|
||||
assert(params->wsize >= nbw2 * ne12);
|
||||
|
||||
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
||||
|
||||
int64_t i11_processed = 0;
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
// INFO: Quantization is done in planes to avoid extra complexity in chunking.
|
||||
// Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
|
||||
// the planes are broadcast.
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
char * data_ptr = (char *) src1->data + i12 * nb12;
|
||||
char * wdata_ptr = wdata + i12 * nbw2;
|
||||
|
||||
i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
|
||||
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
|
||||
const int64_t i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
|
||||
}
|
||||
}
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
int64_t nr = ggml_nrows(op->src[0]);
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
const int64_t nr0 = ggml_nrows(op->src[0]);
|
||||
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
|
||||
|
||||
// src1 is chunked only by full planes.
|
||||
// When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
|
||||
// to route them thorugh GEMV.
|
||||
// nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
|
||||
// to avoid affecting their performance
|
||||
int64_t nchunk1 = ne12;
|
||||
|
||||
// Ensure minimum chunk size to avoid alignment issues with high thread counts
|
||||
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
|
||||
const int64_t min_chunk_size = NB_COLS;
|
||||
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
|
||||
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
|
||||
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
}
|
||||
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
if (nth == 1 || nchunk0 < nth || disable_chunking) {
|
||||
nchunk0 = nth;
|
||||
}
|
||||
|
||||
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
|
||||
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
|
||||
// This prevents creating too many tiny chunks that could overlap after alignment
|
||||
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk > max_nchunk) {
|
||||
nchunk = max_nchunk;
|
||||
}
|
||||
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
nchunk0 = MIN(nchunk0, max_nchunk);
|
||||
|
||||
if (ith == 0) {
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
@@ -1706,23 +1752,30 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
int64_t src0_start = (current_chunk * ne01) / nchunk;
|
||||
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
|
||||
int64_t src0_start = dr0 * ith0;
|
||||
int64_t src0_end = MIN(src0_start + dr0, nr0);
|
||||
|
||||
// full-plane range for src1
|
||||
int64_t src1_start = ith1 * ne11;
|
||||
int64_t src1_end = (ith1 + 1) * ne11;
|
||||
|
||||
// Align boundaries to NB_COLS - round up to ensure all data is included
|
||||
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
|
||||
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
if (src0_end > ne01) {
|
||||
src0_end = ne01;
|
||||
}
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
src0_end = MIN(src0_end, ne01);
|
||||
|
||||
// Make sure current plane is the last one before exiting
|
||||
if (src0_start >= src0_end) {
|
||||
break;
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
continue;
|
||||
}
|
||||
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
|
||||
@@ -73,6 +73,14 @@ static inline float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static inline float op_expm1(float x) {
|
||||
return expf(x) - 1.0f;
|
||||
}
|
||||
|
||||
static inline float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static inline float op_floor(float x) {
|
||||
return floorf(x);
|
||||
}
|
||||
@@ -290,6 +298,14 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
|
||||
unary_op<op_log>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_expm1>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_softplus>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_floor>(params, dst);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -1416,6 +1416,16 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (i == 0) {
|
||||
y[i] = x[i];
|
||||
} else {
|
||||
y[i] = y[i - 1] + x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
|
||||
ggml_float sum = 0.0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
|
||||
@@ -2527,6 +2527,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
ggml_cuda_op_trunc(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
ggml_cuda_op_expm1(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_cuda_op_softplus(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3829,6 +3835,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
|
||||
@@ -81,6 +81,14 @@ static __device__ __forceinline__ float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_expm1(float x) {
|
||||
return expm1f(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_elu(float x) {
|
||||
return (x > 0.f) ? x : expm1f(x);
|
||||
}
|
||||
@@ -233,6 +241,14 @@ void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_trunc>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_expm1>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_softplus>(ctx, dst);
|
||||
}
|
||||
/* gated ops */
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
|
||||
@@ -61,6 +61,10 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline float ggml_softplus(float input) {
|
||||
static inline float ggml_compute_softplus_f32(float input) {
|
||||
return (input > 20.0f) ? input : logf(1 + expf(input));
|
||||
}
|
||||
//
|
||||
|
||||
@@ -943,6 +943,34 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_ARGSORT);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
ggml_sort_order order = (ggml_sort_order) op->op_params[0];
|
||||
|
||||
const char * order_str = "undefined";
|
||||
switch (order) {
|
||||
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
||||
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
|
||||
@@ -125,6 +125,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -904,8 +904,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARGSORT:
|
||||
// TODO: Support arbitrary column width
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
case GGML_OP_ARANGE:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
|
||||
@@ -793,10 +793,28 @@ typedef struct {
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int64_t ncols;
|
||||
int64_t ncols_pad;
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
} ggml_metal_kargs_argsort;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t len;
|
||||
} ggml_metal_kargs_argsort_merge;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne0;
|
||||
float start;
|
||||
|
||||
@@ -1975,7 +1975,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
||||
// note: always reserve the padding space to avoid graph reallocations
|
||||
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
||||
const bool has_kvpad = true;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
|
||||
@@ -1984,7 +1986,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
} else {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
||||
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
||||
const bool has_kvpad = true;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_NCPSG*(
|
||||
@@ -2020,9 +2023,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
||||
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
|
||||
|
||||
// this optimization is not useful for the vector kernels
|
||||
if (is_vec) {
|
||||
return res;
|
||||
}
|
||||
// note: always reserve the blk buffer to avoid graph reallocations
|
||||
//if (is_vec) {
|
||||
// return res;
|
||||
//}
|
||||
|
||||
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
|
||||
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
||||
@@ -2049,13 +2053,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
||||
|
||||
size_t res = 0;
|
||||
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
// note: always reserve the temp buffer to avoid graph reallocations
|
||||
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
if (true) {
|
||||
const int64_t nwg = 32;
|
||||
const int64_t ne01_max = std::min(ne01, 32);
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -3523,38 +3530,95 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int64_t ne00_padded = 1;
|
||||
while (ne00_padded < ne00) {
|
||||
ne00_padded *= 2;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int nth = 1;
|
||||
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
const int nptg = (ne00 + nth - 1)/nth;
|
||||
|
||||
// Metal kernels require the buffer size to be multiple of 16 bytes
|
||||
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
||||
const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
|
||||
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
|
||||
if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
|
||||
std::swap(bid_dst, bid_tmp);
|
||||
}
|
||||
|
||||
ggml_metal_kargs_argsort args = {
|
||||
/*.ncols =*/ ne00,
|
||||
/*.ncols_pad =*/ ne00_padded
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||
|
||||
int len = nth;
|
||||
|
||||
while (len < ne00) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
ggml_metal_kargs_argsort_merge args_merge = {
|
||||
.ne00 = ne00,
|
||||
.ne01 = ne01,
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.nb00 = nb00,
|
||||
.nb01 = nb01,
|
||||
.nb02 = nb02,
|
||||
.nb03 = nb03,
|
||||
.len = len,
|
||||
};
|
||||
|
||||
// merges per row
|
||||
const int nm = (ne00 + 2*len - 1) / (2*len);
|
||||
|
||||
const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
||||
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
std::swap(bid_dst, bid_tmp);
|
||||
|
||||
len <<= 1;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -197,6 +197,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
||||
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
} break;
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
res *= 2;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -4541,69 +4541,179 @@ kernel void kernel_timestep_embedding_f32(
|
||||
// bitonic sort implementation following the CUDA kernels as reference
|
||||
typedef void (argsort_t)(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const float * x,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]);
|
||||
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_f32_i32(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const float * x,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
||||
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// bitonic sort
|
||||
int col = tpitg[0];
|
||||
int row = tgpig[1];
|
||||
const int col = tpitg[0];
|
||||
|
||||
if (col >= args.ncols_pad) return;
|
||||
const int i00 = (tgpig[0]/args.ne01)*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * x_row = x + row * args.ncols;
|
||||
threadgroup int32_t * dst_row = shared_values;
|
||||
device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||
|
||||
// initialize indices
|
||||
dst_row[col] = col;
|
||||
smem_i32[col] = i00 + col;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int k = 2; k <= args.ncols_pad; k *= 2) {
|
||||
for (int k = 2; k <= ntg.x; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= args.ncols ||
|
||||
(dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
if (smem_i32[col] >= args.ne00 ||
|
||||
(smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
|
||||
x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(dst_row[col], dst_row[ixj]);
|
||||
SWAP(smem_i32[col], smem_i32[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= args.ncols ||
|
||||
(dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
if (smem_i32[ixj] >= args.ne00 ||
|
||||
(smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
|
||||
x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(dst_row[col], dst_row[ixj]);
|
||||
SWAP(smem_i32[col], smem_i32[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < args.ncols) {
|
||||
dst[row * args.ncols + col] = dst_row[col];
|
||||
if (i00 + col < args.ne00) {
|
||||
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
dst[col] = smem_i32[col];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
typedef void (argsort_merge_t)(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_merge_f32_i32(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int im = tgpig[0] / args.ne01;
|
||||
int i01 = tgpig[0] % args.ne01;
|
||||
int i02 = tgpig[1];
|
||||
int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2*args.len);
|
||||
|
||||
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
|
||||
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
|
||||
|
||||
const int total = len0 + len1;
|
||||
|
||||
device const int32_t * tmp0 = tmp + start
|
||||
+ i01*args.ne00
|
||||
+ i02*args.ne00*args.ne01
|
||||
+ i03*args.ne00*args.ne01*args.ne02;
|
||||
|
||||
device const int32_t * tmp1 = tmp0 + args.len;
|
||||
|
||||
dst += start
|
||||
+ i01*args.ne00
|
||||
+ i02*args.ne00*args.ne01
|
||||
+ i03*args.ne00*args.ne01*args.ne02;
|
||||
|
||||
device const float * src0_row = (device const float *)(src0
|
||||
+ args.nb01*i01
|
||||
+ args.nb02*i02
|
||||
+ args.nb03*i03);
|
||||
|
||||
for (int k = tpitg.x; k < (int) total; k += ntg.x) {
|
||||
// find partition (i,j) such that i+j = k
|
||||
int low = k > len1 ? k - len1 : 0;
|
||||
int high = MIN(k, len0);
|
||||
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k - mid - 1];
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (val0 <= val1) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
} else {
|
||||
if (val0 >= val1) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int i = low;
|
||||
const int j = k - i;
|
||||
|
||||
int32_t out_idx;
|
||||
|
||||
if (i >= len0) {
|
||||
out_idx = tmp1[j];
|
||||
} else if (j >= len1) {
|
||||
out_idx = tmp0[i];
|
||||
} else {
|
||||
const int32_t idx0 = tmp0[i];
|
||||
const int32_t idx1 = tmp1[j];
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
|
||||
out_idx = (order == GGML_SORT_ORDER_ASC)
|
||||
? (val0 <= val1 ? idx0 : idx1)
|
||||
: (val0 >= val1 ? idx0 : idx1);
|
||||
}
|
||||
|
||||
dst[k] = out_idx;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
kernel void kernel_leaky_relu_f32(
|
||||
constant ggml_metal_kargs_leaky_relu & args,
|
||||
device const float * src0,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,21 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_d[i] = D_TYPE(abs(float(data_a[i])));
|
||||
}
|
||||
@@ -7,6 +7,7 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
@@ -108,6 +109,38 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
@@ -153,21 +186,6 @@ void main() {
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float mvf = masksh[c * cols_per_iter + col_tid][r];
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
|
||||
@@ -148,6 +149,37 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
float mask_cache[Bc * Br / WorkGroupSize];
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
mask_cache[idx / WorkGroupSize] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
@@ -208,7 +240,8 @@ void main() {
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
float f = mask_cache[idx / WorkGroupSize];
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return max(x, y);
|
||||
}
|
||||
|
||||
float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
|
||||
return max(x, y);
|
||||
}
|
||||
|
||||
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return x;
|
||||
}
|
||||
@@ -142,6 +146,44 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (nem1_bounds_check) {
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
@@ -158,31 +200,7 @@ void main() {
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (nem1_bounds_check) {
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
}
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
}
|
||||
|
||||
// Clear padding elements to -inf, so they don't contribute to rowmax
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
data_d[i] = D_TYPE(-float(data_a[i]));
|
||||
}
|
||||
@@ -76,7 +76,7 @@ enum MatMulIdType {
|
||||
|
||||
namespace {
|
||||
|
||||
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
HANDLE stdout_read, stdout_write;
|
||||
HANDLE stderr_read, stderr_write;
|
||||
@@ -99,8 +99,10 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
|
||||
si.hStdOutput = stdout_write;
|
||||
si.hStdError = stderr_write;
|
||||
|
||||
std::vector<char> cmd(command.begin(), command.end());
|
||||
cmd.push_back('\0');
|
||||
std::string cmd;
|
||||
for (const auto& part : command) {
|
||||
cmd += part + " ";
|
||||
}
|
||||
|
||||
if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
|
||||
throw std::runtime_error("Failed to create process");
|
||||
@@ -138,6 +140,12 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
|
||||
throw std::runtime_error("Failed to fork process");
|
||||
}
|
||||
|
||||
std::vector<char*> argv;
|
||||
for (std::string& part : command) {
|
||||
argv.push_back(part.data());
|
||||
}
|
||||
argv.push_back(nullptr);
|
||||
|
||||
if (pid == 0) {
|
||||
close(stdout_pipe[0]);
|
||||
close(stderr_pipe[0]);
|
||||
@@ -145,7 +153,7 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
|
||||
dup2(stderr_pipe[1], STDERR_FILENO);
|
||||
close(stdout_pipe[1]);
|
||||
close(stderr_pipe[1]);
|
||||
execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
|
||||
execvp(argv[0], argv.data());
|
||||
_exit(EXIT_FAILURE);
|
||||
} else {
|
||||
close(stdout_pipe[1]);
|
||||
@@ -316,21 +324,27 @@ compile_count_guard acquire_compile_slot() {
|
||||
void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
|
||||
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
|
||||
|
||||
#ifdef _WIN32
|
||||
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
|
||||
#else
|
||||
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
|
||||
#endif
|
||||
|
||||
// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
|
||||
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
|
||||
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
|
||||
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos || name.find("rope") != std::string::npos) ? "" : "-O";
|
||||
|
||||
#ifdef _WIN32
|
||||
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
|
||||
#else
|
||||
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path};
|
||||
#endif
|
||||
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
|
||||
cmd.push_back("-O");
|
||||
}
|
||||
|
||||
if (dep_file) {
|
||||
cmd.push_back("-MD");
|
||||
cmd.push_back("-MF");
|
||||
#ifdef _WIN32
|
||||
cmd.push_back("\"" + target_cpp + ".d\"");
|
||||
#else
|
||||
cmd.push_back(target_cpp + ".d");
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
|
||||
@@ -354,9 +368,13 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
execute_command(command, stdout_str, stderr_str);
|
||||
execute_command(cmd, stdout_str, stderr_str);
|
||||
if (!stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
|
||||
std::cerr << "cannot compile " << name << "\n\n";
|
||||
for (const auto& part : cmd) {
|
||||
std::cerr << part << " ";
|
||||
}
|
||||
std::cerr << "\n\n" << stderr_str << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -430,7 +448,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
|
||||
if (f16acc) {
|
||||
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
|
||||
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
|
||||
if (coopmat) {
|
||||
@@ -610,7 +628,7 @@ void process_shaders() {
|
||||
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
|
||||
if (f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
@@ -809,6 +827,8 @@ void process_shaders() {
|
||||
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -817,6 +837,8 @@ void process_shaders() {
|
||||
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
@@ -1081,11 +1103,6 @@ int main(int argc, char** argv) {
|
||||
|
||||
if (args.find("--glslc") != args.end()) {
|
||||
GLSLC = args["--glslc"]; // Path to glslc
|
||||
|
||||
if (!std::filesystem::exists(GLSLC) || !std::filesystem::is_regular_file(GLSLC)) {
|
||||
std::cerr << "Error: glslc not found at " << GLSLC << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
if (args.find("--source") != args.end()) {
|
||||
input_filepath = args["--source"]; // The shader source file to compile
|
||||
|
||||
+154
-5
@@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"COS",
|
||||
"SUM",
|
||||
"SUM_ROWS",
|
||||
"CUMSUM",
|
||||
"MEAN",
|
||||
"ARGMAX",
|
||||
"COUNT_EQUAL",
|
||||
@@ -990,6 +991,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"TIMESTEP_EMBEDDING",
|
||||
"ARGSORT",
|
||||
"LEAKY_RELU",
|
||||
"TRI",
|
||||
"FILL",
|
||||
|
||||
"FLASH_ATTN_EXT",
|
||||
"FLASH_ATTN_BACK",
|
||||
@@ -1002,6 +1005,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"RWKV_WKV6",
|
||||
"GATED_LINEAR_ATTN",
|
||||
"RWKV_WKV7",
|
||||
"SOLVE_TRI",
|
||||
|
||||
"UNARY",
|
||||
|
||||
@@ -1019,7 +1023,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
|
||||
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1039,6 +1043,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cos(x)",
|
||||
"Σx",
|
||||
"Σx_k",
|
||||
"cumsum(x)",
|
||||
"Σx/n",
|
||||
"argmax(x)",
|
||||
"count_equal(x)",
|
||||
@@ -1094,6 +1099,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"timestep_embedding(timesteps, dim, max_period)",
|
||||
"argsort(x)",
|
||||
"leaky_relu(x)",
|
||||
"tri(x)",
|
||||
"fill(x, c)",
|
||||
|
||||
"flash_attn_ext(x)",
|
||||
"flash_attn_back(x)",
|
||||
@@ -1106,6 +1113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
"gated_linear_attn(k, v, q, gate, s)",
|
||||
"rwkv_wkv7(r, w, k, v, a, b, s)",
|
||||
"A X = B, A triangular, solve X",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
@@ -1123,7 +1131,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
|
||||
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -1142,6 +1150,8 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"HARDSWISH",
|
||||
"HARDSIGMOID",
|
||||
"EXP",
|
||||
"EXPM1",
|
||||
"SOFTPLUS",
|
||||
"GELU_ERF",
|
||||
"XIELU",
|
||||
"FLOOR",
|
||||
@@ -1150,7 +1160,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"TRUNC",
|
||||
};
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20");
|
||||
static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
|
||||
|
||||
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
|
||||
"REGLU",
|
||||
@@ -2258,6 +2268,30 @@ struct ggml_tensor * ggml_log_inplace(
|
||||
return ggml_log_impl(ctx, a, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_expm1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_expm1_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_softplus(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_softplus_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);
|
||||
}
|
||||
|
||||
// ggml_sin
|
||||
|
||||
static struct ggml_tensor * ggml_sin_impl(
|
||||
@@ -2341,6 +2375,21 @@ struct ggml_tensor * ggml_sum_rows(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_cumsum
|
||||
|
||||
struct ggml_tensor * ggml_cumsum(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_CUMSUM;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_mean
|
||||
|
||||
struct ggml_tensor * ggml_mean(
|
||||
@@ -2668,8 +2717,8 @@ struct ggml_tensor * ggml_xielu(
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
|
||||
ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
|
||||
ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
|
||||
ggml_set_op_params_f32(result, 1, beta + ggml_compute_softplus_f32(alpha_n));
|
||||
ggml_set_op_params_f32(result, 2, ggml_compute_softplus_f32(alpha_p));
|
||||
ggml_set_op_params_f32(result, 3, beta);
|
||||
ggml_set_op_params_f32(result, 4, eps);
|
||||
|
||||
@@ -5028,6 +5077,61 @@ struct ggml_tensor * ggml_timestep_embedding(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_tri
|
||||
|
||||
struct ggml_tensor * ggml_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_tri_type type) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(a->ne[0] == a->ne[1]);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, type);
|
||||
|
||||
result->op = GGML_OP_TRI;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_fill
|
||||
|
||||
static struct ggml_tensor * ggml_fill_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_f32(result, 0, c);
|
||||
|
||||
result->op = GGML_OP_FILL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_fill(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c) {
|
||||
return ggml_fill_impl(ctx, a, c, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_fill_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c) {
|
||||
return ggml_fill_impl(ctx, a, c, true);
|
||||
}
|
||||
|
||||
// ggml_argsort
|
||||
|
||||
struct ggml_tensor * ggml_argsort(
|
||||
@@ -5882,6 +5986,41 @@ struct ggml_tensor * ggml_opt_step_sgd(
|
||||
return result;
|
||||
}
|
||||
|
||||
// solve_tri
|
||||
|
||||
struct ggml_tensor * ggml_solve_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool left,
|
||||
bool lower,
|
||||
bool uni) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
|
||||
// A must be square and lower diagonal
|
||||
GGML_ASSERT(a->ne[0] == a->ne[1]);
|
||||
// B must have same outer dimension as A
|
||||
GGML_ASSERT(a->ne[1] == b->ne[1]);
|
||||
|
||||
// batch dimensions must be equal
|
||||
GGML_ASSERT(a->ne[2] == b->ne[2]);
|
||||
GGML_ASSERT(a->ne[3] == b->ne[3]);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(ggml_is_contiguous(b));
|
||||
|
||||
GGML_ASSERT(lower && left && !uni); // TODO: support other variants
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
|
||||
|
||||
result->op = GGML_OP_SOLVE_TRI;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ggml_hash_set ggml_hash_set_new(size_t size) {
|
||||
@@ -6454,6 +6593,16 @@ static void ggml_compute_backward(
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXPM1: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0)));
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sigmoid(ctx, src0)));
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
|
||||
__func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
|
||||
|
||||
@@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum):
|
||||
BAILINGMOE2 = auto()
|
||||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
AFMOE = auto()
|
||||
ERNIE4_5 = auto()
|
||||
ERNIE4_5_MOE = auto()
|
||||
HUNYUAN_MOE = auto()
|
||||
@@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_POST_NORM = auto()
|
||||
ATTN_ROT_EMBD = auto()
|
||||
ATTN_SINKS = auto()
|
||||
ATTN_GATE = auto()
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
@@ -776,6 +778,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
|
||||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
MODEL_ARCH.AFMOE: "afmoe",
|
||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
|
||||
MODEL_ARCH.FALCON_H1: "falcon-h1",
|
||||
@@ -828,6 +831,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
|
||||
MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
|
||||
@@ -2693,6 +2697,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.AFMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_GATE,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.ERNIE4_5: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@@ -314,6 +314,10 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.sinks", # openai-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_GATE: (
|
||||
"model.layers.{bid}.self_attn.gate_proj", # afmoe
|
||||
),
|
||||
|
||||
# Feed-forward norm
|
||||
MODEL_TENSOR.FFN_NORM: (
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||
@@ -340,11 +344,12 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.feedforward_layernorm", # apertus
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
# Pre feed-forward norm
|
||||
MODEL_TENSOR.FFN_PRE_NORM: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
"layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
|
||||
"model.layers.{bid}.pre_ff_layernorm.weight",
|
||||
"model.layers.{bid}.pre_mlp_layernorm", # afmoe
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@@ -370,6 +375,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
||||
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
||||
"model.layers.{bid}.mlp.router.gate", # afmoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@@ -380,6 +386,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
||||
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
||||
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
|
||||
"model.layers.{bid}.mlp.expert_bias", # afmoe
|
||||
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
|
||||
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
|
||||
),
|
||||
|
||||
@@ -35,6 +35,7 @@ add_library(llama
|
||||
unicode-data.cpp
|
||||
unicode.cpp
|
||||
unicode.h
|
||||
models/afmoe.cpp
|
||||
models/apertus.cpp
|
||||
models/arcee.cpp
|
||||
models/arctic.cpp
|
||||
|
||||
@@ -90,6 +90,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_AFMOE, "afmoe" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
@@ -333,6 +334,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_AFMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA4,
|
||||
{
|
||||
@@ -2444,6 +2475,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
||||
@@ -94,6 +94,7 @@ enum llm_arch {
|
||||
LLM_ARCH_BAILINGMOE2,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_AFMOE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
@@ -312,6 +313,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_ATTN_GATE,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
|
||||
@@ -84,6 +84,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_15B: return "15B";
|
||||
case LLM_TYPE_16B: return "16B";
|
||||
case LLM_TYPE_20B: return "20B";
|
||||
case LLM_TYPE_26B: return "26B";
|
||||
case LLM_TYPE_27B: return "27B";
|
||||
case LLM_TYPE_30B: return "30B";
|
||||
case LLM_TYPE_32B: return "32B";
|
||||
@@ -695,6 +696,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
|
||||
// Set up interleaved sliding window attention (ISWA)
|
||||
// Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4)
|
||||
if (hparams.n_swa > 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.set_swa_pattern(4);
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
}
|
||||
|
||||
// Default to sigmoid if not set
|
||||
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 56: type = LLM_TYPE_6B; break;
|
||||
case 32: type = LLM_TYPE_26B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DECI:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@@ -5749,6 +5781,71 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
const int64_t n_expert_shared = hparams.n_expert_shared;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
// dual attention normalization
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// attention projections
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
// Q/K normalization
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
// attention gating
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
|
||||
// dual ffn normalization
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
|
||||
// MoE layers
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
|
||||
|
||||
// grouped expert weights
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
|
||||
// shared expert
|
||||
if (n_expert_shared > 0) {
|
||||
const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
}
|
||||
} else {
|
||||
// Dense layers
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_ERNIE4_5_MOE:
|
||||
{
|
||||
@@ -7243,6 +7340,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_arcee>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_afmoe>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
{
|
||||
llm = std::make_unique<llm_build_ernie4_5>(*this, params);
|
||||
@@ -7528,6 +7629,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_MINIMAX_M2:
|
||||
case LLM_ARCH_COGVLM:
|
||||
case LLM_ARCH_PANGU_EMBED:
|
||||
case LLM_ARCH_AFMOE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
@@ -76,6 +76,7 @@ enum llm_type {
|
||||
LLM_TYPE_15B,
|
||||
LLM_TYPE_16B,
|
||||
LLM_TYPE_20B,
|
||||
LLM_TYPE_26B,
|
||||
LLM_TYPE_27B,
|
||||
LLM_TYPE_30B,
|
||||
LLM_TYPE_32B,
|
||||
@@ -234,6 +235,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * wk_enc = nullptr;
|
||||
struct ggml_tensor * wv_enc = nullptr;
|
||||
struct ggml_tensor * wo_enc = nullptr;
|
||||
struct ggml_tensor * wqkv_gate = nullptr;
|
||||
|
||||
// attention bias
|
||||
struct ggml_tensor * bq = nullptr;
|
||||
|
||||
+10
-5
@@ -4,6 +4,7 @@
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
@@ -1625,10 +1626,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
auto * ctx = new llama_sampler_grammar;
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
std::string trigger_pattern;
|
||||
llama_grammar * grammar = nullptr;
|
||||
// TODO: remove trigger_words support.
|
||||
if (trigger_words != nullptr && num_trigger_words > 0) {
|
||||
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
|
||||
std::string trigger_pattern("[\\s\\S]*?(");
|
||||
trigger_pattern = "[\\s\\S]*?(";
|
||||
for (size_t i = 0; i < num_trigger_words; ++i) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
if (i > 0) {
|
||||
@@ -1637,15 +1640,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
|
||||
}
|
||||
trigger_pattern += ")[\\s\\S]*";
|
||||
const auto * trigger_pattern_c = trigger_pattern.c_str();
|
||||
trigger_patterns = &trigger_pattern_c;
|
||||
num_trigger_patterns = 1;
|
||||
|
||||
std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
|
||||
} else {
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
|
||||
}
|
||||
*ctx = {
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ grammar_str,
|
||||
/* .grammar_root = */ grammar_root,
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
/* .grammar = */ grammar,
|
||||
};
|
||||
if (!ctx->grammar) {
|
||||
delete ctx;
|
||||
|
||||
@@ -443,6 +443,17 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_AFMOE:
|
||||
regex_exprs = {
|
||||
// Digit handling - uses custom implementation in unicode.cpp
|
||||
// Groups digits with leading 1-2 based on total length modulo 3
|
||||
"\\p{AFMoE_digits}",
|
||||
// CJK and Asian scripts (using direct Unicode literals)
|
||||
"[一-鿿㐀-䶿豈--ゟ゠-ヿ・-゚⼀-เ--ក-က-႟ꩠ-ꩿꧠ-가-ᄀ-ᇿ]+",
|
||||
// Main BPE pattern
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@@ -1993,6 +2004,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "grok-2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "afmoe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_AFMOE;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "minimax-m2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
|
||||
|
||||
@@ -50,6 +50,7 @@ enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
|
||||
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
|
||||
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
|
||||
};
|
||||
|
||||
struct LLM_KV;
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// MuP scaling: embeddings * sqrt(hidden_size)
|
||||
// mup_enabled = true, hidden_size = 1024, scale = 32.0
|
||||
inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
|
||||
cb(inpL, "inp_embd_scaled", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// dual attention normalization (pre)
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * attn_inp = cur; // save input for gate computation
|
||||
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// compute gate from input
|
||||
ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
|
||||
cb(gate, "attn_gate_proj", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// Q/K normalization
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
// RoPE only for sliding_attention layers
|
||||
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
|
||||
((il + 1) % hparams.n_no_rope_layer_step) != 0;
|
||||
if (use_rope) {
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur_rope", il);
|
||||
}
|
||||
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
NULL, NULL, // wo will be applied after gating
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
|
||||
// attention gating: attn_out * sigmoid(gate) BEFORE o_proj
|
||||
gate = ggml_sigmoid(ctx0, gate);
|
||||
cb(gate, "attn_gate_sig", il);
|
||||
cur = ggml_mul(ctx0, cur, gate);
|
||||
cb(cur, "attn_gated", il);
|
||||
|
||||
// now apply output projection
|
||||
cur = build_lora_mm(model.layers[il].wo, cur);
|
||||
cb(cur, "attn_o_proj", il);
|
||||
}
|
||||
|
||||
// dual attention normalization (post)
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].attn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// dual ffn normalization (pre)
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// MoE or dense FFN
|
||||
if ((uint32_t)il >= hparams.n_layer_dense_lead) {
|
||||
// MoE layer with sigmoid routing, normalization, and scaling
|
||||
ggml_tensor * moe_out = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU,
|
||||
hparams.expert_weights_norm, // norm_w (route_norm=True)
|
||||
hparams.expert_weights_scale, // scale_w
|
||||
hparams.expert_weights_scale, // w_scale (route_scale=2.826)
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
// shared expert
|
||||
if (hparams.n_expert_shared > 0) {
|
||||
ggml_tensor * ffn_shexp = build_ffn(cur,
|
||||
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(ffn_shexp, "ffn_shexp", il);
|
||||
|
||||
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = moe_out;
|
||||
}
|
||||
} else {
|
||||
// dense layer
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
// dual ffn normalization (post)
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].ffn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
@@ -57,6 +57,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
int il) const;
|
||||
};
|
||||
|
||||
struct llm_build_afmoe : public llm_graph_context {
|
||||
llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_apertus : public llm_graph_context {
|
||||
llm_build_apertus(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
@@ -729,6 +729,80 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// AFMOE digit handling: splits digits with leading 1-2 based on total length modulo 3
|
||||
static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
bpe_offsets.reserve(offsets.size());
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
const size_t offset_ini = start;
|
||||
const size_t offset_end = start + offset;
|
||||
assert(offset_end <= cpts.size());
|
||||
start = offset_end;
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
auto _add_token = [&] (const size_t end) -> size_t {
|
||||
assert(_prev_end <= end && end <= offset_end);
|
||||
size_t len = end - _prev_end;
|
||||
if (len > 0) {
|
||||
bpe_offsets.push_back(len);
|
||||
}
|
||||
_prev_end = end;
|
||||
return len;
|
||||
};
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; ) {
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// Handle digit sequences with special splitting logic
|
||||
if (flags.is_number) {
|
||||
size_t digit_start = pos;
|
||||
size_t digit_count = 0;
|
||||
|
||||
// Count consecutive digits
|
||||
while (_get_flags(pos).is_number && pos < offset_end) {
|
||||
digit_count++;
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Split based on total length modulo 3
|
||||
size_t remainder = digit_count % 3;
|
||||
size_t current = digit_start;
|
||||
|
||||
// Emit leading 1-2 digits if needed
|
||||
if (remainder > 0) {
|
||||
_add_token(current + remainder);
|
||||
current += remainder;
|
||||
}
|
||||
|
||||
// Emit groups of 3
|
||||
while (current < digit_start + digit_count) {
|
||||
_add_token(current + 3);
|
||||
current += 3;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// For non-digits, just move forward
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Add any remaining content
|
||||
if (_prev_end < offset_end) {
|
||||
_add_token(offset_end);
|
||||
}
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
|
||||
@@ -742,6 +816,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
|
||||
} else if (regex_expr == "\\p{Han}+") {
|
||||
// K2's first pattern - handle all K2 patterns together
|
||||
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
|
||||
} else if (regex_expr == "\\p{AFMoE_digits}") {
|
||||
// AFMOE digit pattern - use custom implementation for proper splitting
|
||||
bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
|
||||
+230
-3
@@ -175,6 +175,38 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t));
|
||||
}
|
||||
|
||||
// generate a lower triangular matrix
|
||||
static void init_tensor_tril(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
||||
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(tensor->ne[0] == tensor->ne[1]);
|
||||
|
||||
GGML_TENSOR_LOCALS(int32_t, ne, tensor, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nb, tensor, nb);
|
||||
|
||||
std::vector<float> data_f32(ne0*ne1*ne2*ne3);
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<float> dis(min, max);
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
||||
int64_t idx = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3) / sizeof(float);
|
||||
if (i0 <= i1) {
|
||||
data_f32[idx] = dis(gen);
|
||||
} else {
|
||||
data_f32[idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(tensor, data_f32.data(), 0, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
static std::vector<float> tensor_to_float(const ggml_tensor * t) {
|
||||
std::vector<float> tv;
|
||||
tv.reserve(ggml_nelements(t));
|
||||
@@ -1804,7 +1836,8 @@ struct test_unary : public test_case {
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||
|
||||
op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU;
|
||||
op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU ||
|
||||
op == GGML_UNARY_OP_EXPM1 || op == GGML_UNARY_OP_SOFTPLUS;
|
||||
|
||||
ggml_tensor * a;
|
||||
if (v & 1) {
|
||||
@@ -2779,7 +2812,7 @@ struct test_bin_bcast : public test_case {
|
||||
const std::array<int, 4> nr;
|
||||
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
bool run_whole_graph() override { return nf > 1; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, nr, nf);
|
||||
@@ -5395,6 +5428,7 @@ struct test_pad : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_PAD (with extension)
|
||||
struct test_pad_ext : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne_a;
|
||||
@@ -5802,6 +5836,7 @@ struct test_opt_step_adamw : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_OPT_STEP_SGD
|
||||
struct test_opt_step_sgd : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
@@ -5841,6 +5876,170 @@ struct test_opt_step_sgd : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CUMSUM
|
||||
struct test_cumsum : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR2(type, ne); }
|
||||
|
||||
test_cumsum(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_cumsum(ctx, a);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_XIELU
|
||||
struct test_xielu : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR2(type, ne); }
|
||||
|
||||
test_xielu(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
float alpha_n = 4.0f;
|
||||
float alpha_p = 20.0f;
|
||||
float beta = 0.5f;
|
||||
float eps = 0.0000001f;
|
||||
|
||||
ggml_tensor * out = ggml_xielu(ctx, a, alpha_n, alpha_p, beta, eps);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_TRI
|
||||
struct test_tri : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const ggml_tri_type tri_type;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR3(type, ne, tri_type); }
|
||||
|
||||
test_tri(ggml_tri_type tri_type, ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
|
||||
: type(type), ne(ne), tri_type(tri_type) {
|
||||
GGML_ASSERT(ne[0] == ne[1]);
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_tri(ctx, a, tri_type);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_FILL
|
||||
struct test_fill : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
float c;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR3(type, ne, c); }
|
||||
|
||||
test_fill(float c, ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
|
||||
: type(type), ne(ne), c(c) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_fill(ctx, a, c);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SOLVE_TRI
|
||||
struct test_solve_tri : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne_lhs;
|
||||
const std::array<int64_t, 4> ne_rhs;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
|
||||
|
||||
test_solve_tri(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
|
||||
std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
|
||||
)
|
||||
: type(type), ne_lhs(ne_lhs), ne_rhs(ne_rhs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne_lhs[0], ne_lhs[1], ne_lhs[2], ne_lhs[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne_rhs[0], ne_rhs[1], ne_rhs[2], ne_rhs[3]);
|
||||
ggml_set_param(b);
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
ggml_tensor * out = ggml_solve_tri(ctx, a, b, true, true, false);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (strcmp(t->name, "a") == 0) {
|
||||
// note: avoid zeros in the diagonal
|
||||
init_tensor_tril(t, 0.1, 1.0f);
|
||||
} else {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
enum llm_norm_type {
|
||||
LLM_NORM,
|
||||
LLM_NORM_RMS,
|
||||
@@ -6282,6 +6481,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
for (int v : {0, 1}) {
|
||||
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
||||
if (op == GGML_UNARY_OP_XIELU) {
|
||||
continue; // need extra params, separate test
|
||||
}
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
|
||||
}
|
||||
@@ -7290,8 +7492,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||
}
|
||||
|
||||
@@ -7339,6 +7546,26 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_arange());
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
test_cases.emplace_back(new test_cumsum());
|
||||
|
||||
test_cases.emplace_back(new test_xielu());
|
||||
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG));
|
||||
|
||||
test_cases.emplace_back(new test_fill(0.0f));
|
||||
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
|
||||
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
|
||||
|
||||
test_cases.emplace_back(new test_solve_tri());
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
|
||||
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
|
||||
|
||||
+5
-12
@@ -224,7 +224,6 @@ static void clip_log_callback_default(enum ggml_log_level level, const char * te
|
||||
}
|
||||
|
||||
struct clip_logger_state {
|
||||
ggml_log_level verbosity_thold;
|
||||
ggml_log_callback log_callback;
|
||||
void * log_callback_user_data;
|
||||
};
|
||||
@@ -258,17 +257,11 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
#define LOG_TMPL(level, ...) \
|
||||
do { \
|
||||
if ((level) >= g_logger_state.verbosity_thold) { \
|
||||
clip_log_internal((level), __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
|
||||
#define LOG_INF(...) clip_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) clip_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) clip_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_DBG(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) clip_log_internal(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
|
||||
|
||||
//
|
||||
// cpp wrappers
|
||||
|
||||
+1
-3
@@ -24,8 +24,7 @@
|
||||
#include <array>
|
||||
#include <functional>
|
||||
|
||||
// TODO: allow to pass callback from user code
|
||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
|
||||
|
||||
enum ffn_op_type {
|
||||
FFN_GELU,
|
||||
@@ -3507,7 +3506,6 @@ struct clip_model_loader {
|
||||
};
|
||||
|
||||
struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
|
||||
g_logger_state.verbosity_thold = ctx_params.verbosity;
|
||||
clip_ctx * ctx_vision = nullptr;
|
||||
clip_ctx * ctx_audio = nullptr;
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ enum clip_flash_attn_type {
|
||||
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
enum ggml_log_level verbosity;
|
||||
enum clip_flash_attn_type flash_attn_type;
|
||||
int image_min_tokens;
|
||||
int image_max_tokens;
|
||||
|
||||
@@ -135,7 +135,6 @@ struct mtmd_cli_context {
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = true;
|
||||
mparams.n_threads = params.cpuparams.n_threads;
|
||||
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params.flash_attn_type;
|
||||
mparams.image_min_tokens = params.image_min_tokens;
|
||||
mparams.image_max_tokens = params.image_max_tokens;
|
||||
@@ -277,6 +276,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
common_init();
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
if (params.mmproj.path.empty()) {
|
||||
show_additional_info(argc, argv);
|
||||
|
||||
@@ -32,8 +32,65 @@
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb/stb_image.h"
|
||||
|
||||
#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
|
||||
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)
|
||||
//
|
||||
// internal logging functions
|
||||
//
|
||||
|
||||
struct mtmd_helper_logger {
|
||||
ggml_log_callback default_callback = [](ggml_log_level level, const char * text, void * user_data) {
|
||||
(void) level;
|
||||
(void) user_data;
|
||||
fputs(text, stderr);
|
||||
fflush(stderr);
|
||||
};
|
||||
|
||||
ggml_log_callback log_callback = default_callback;
|
||||
void * log_callback_user_data;
|
||||
|
||||
void log_v(enum ggml_log_level level, const char * format, va_list args) {
|
||||
if (format == NULL) {
|
||||
return;
|
||||
}
|
||||
va_list args_copy;
|
||||
va_copy(args_copy, args);
|
||||
char buffer[128];
|
||||
int len = vsnprintf(buffer, 128, format, args);
|
||||
if (len < 128) {
|
||||
log_callback(level, buffer, log_callback_user_data);
|
||||
} else {
|
||||
char * buffer2 = (char *) calloc(len + 1, sizeof(char));
|
||||
vsnprintf(buffer2, len + 1, format, args_copy);
|
||||
buffer2[len] = 0;
|
||||
log_callback(level, buffer2, log_callback_user_data);
|
||||
free(buffer2);
|
||||
}
|
||||
va_end(args_copy);
|
||||
}
|
||||
|
||||
void log(enum ggml_log_level level, const char * format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
log_v(level, format, args);
|
||||
va_end(args);
|
||||
}
|
||||
} g_logger;
|
||||
|
||||
#define LOG_INF(...) g_logger.log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) g_logger.log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) g_logger.log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
|
||||
void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
if (log_callback == nullptr) {
|
||||
log_callback = g_logger.default_callback;
|
||||
}
|
||||
g_logger.log_callback = log_callback;
|
||||
g_logger.log_callback_user_data = user_data;
|
||||
mtmd_log_set(log_callback, user_data);
|
||||
}
|
||||
|
||||
//
|
||||
// helper functions
|
||||
//
|
||||
|
||||
size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
|
||||
size_t n_tokens = 0;
|
||||
@@ -325,7 +382,7 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
|
||||
llama_pos * new_n_past) {
|
||||
size_t n_chunks = mtmd_input_chunks_size(chunks);
|
||||
if (n_chunks == 0) {
|
||||
LOG_ERR("no chunks to eval\n");
|
||||
LOG_WRN("no chunks to eval\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,11 @@ extern "C" {
|
||||
// BREAKING CHANGES are expected.
|
||||
//
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
// Note: this also call mtmd_log_set() internally
|
||||
MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a file
|
||||
// it calls mtmd_helper_bitmap_init_from_buf() internally
|
||||
// returns nullptr on failure
|
||||
|
||||
+5
-2
@@ -105,7 +105,6 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
/* use_gpu */ true,
|
||||
/* print_timings */ true,
|
||||
/* n_threads */ 4,
|
||||
/* verbosity */ GGML_LOG_LEVEL_INFO,
|
||||
/* image_marker */ MTMD_DEFAULT_IMAGE_MARKER,
|
||||
/* media_marker */ mtmd_default_marker(),
|
||||
/* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
|
||||
@@ -175,7 +174,6 @@ struct mtmd_context {
|
||||
|
||||
clip_context_params ctx_clip_params {
|
||||
/* use_gpu */ ctx_params.use_gpu,
|
||||
/* verbosity */ ctx_params.verbosity,
|
||||
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
|
||||
/* image_min_tokens */ ctx_params.image_min_tokens,
|
||||
/* image_max_tokens */ ctx_params.image_max_tokens,
|
||||
@@ -1096,3 +1094,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
void mtmd_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
g_logger_state.log_callback = log_callback ? log_callback : clip_log_callback_default;
|
||||
g_logger_state.log_callback_user_data = user_data;
|
||||
}
|
||||
|
||||
+4
-1
@@ -79,7 +79,6 @@ struct mtmd_context_params {
|
||||
bool use_gpu;
|
||||
bool print_timings;
|
||||
int n_threads;
|
||||
enum ggml_log_level verbosity;
|
||||
const char * image_marker; // deprecated, use media_marker instead
|
||||
const char * media_marker;
|
||||
enum llama_flash_attn_type flash_attn_type;
|
||||
@@ -215,6 +214,10 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
|
||||
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
/////////////////////////////////////////
|
||||
|
||||
// test function, to be used in test-mtmd-c-api.c
|
||||
|
||||
Binary file not shown.
+55
-50
@@ -2454,11 +2454,12 @@ struct server_context {
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params_base.flash_attn_type;
|
||||
mparams.image_min_tokens = params_base.image_min_tokens;
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
@@ -3591,13 +3592,13 @@ struct server_context {
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if we can batch this slot with the previous one
|
||||
if (slot.is_processing()) {
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
} else if (!slot_batched->can_batch_with(slot)) {
|
||||
continue;
|
||||
}
|
||||
if (slot_batched && !slot_batched->can_batch_with(slot)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// this slot still has a prompt to be processed
|
||||
@@ -4028,6 +4029,10 @@ struct server_context {
|
||||
}
|
||||
}
|
||||
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
}
|
||||
|
||||
if (batch.n_tokens >= n_batch) {
|
||||
break;
|
||||
}
|
||||
@@ -4431,7 +4436,7 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
|
||||
SRV_DBG("response: %s\n", res.body.c_str());
|
||||
}
|
||||
|
||||
static void res_error(httplib::Response & res, const json & error_data) {
|
||||
static void res_err(httplib::Response & res, const json & error_data) {
|
||||
json final_response {{"error", error_data}};
|
||||
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
@@ -4524,7 +4529,7 @@ int main(int argc, char ** argv) {
|
||||
try {
|
||||
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
||||
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
|
||||
res_error(res, formatted_error);
|
||||
res_err(res, formatted_error);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
|
||||
}
|
||||
@@ -4532,9 +4537,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
||||
if (res.status == 404) {
|
||||
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
}
|
||||
// for other error codes, we skip processing here because it's already done by res_error()
|
||||
// for other error codes, we skip processing here because it's already done by res_err()
|
||||
});
|
||||
|
||||
// set timeouts and change hostname and port
|
||||
@@ -4591,7 +4596,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
res_err(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
|
||||
LOG_WRN("Unauthorized: Invalid API Key\n");
|
||||
|
||||
@@ -4609,7 +4614,7 @@ int main(int argc, char ** argv) {
|
||||
// allow the models endpoint to be accessed during loading
|
||||
return true;
|
||||
} else {
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -4648,7 +4653,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!params.endpoint_slots) {
|
||||
res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4666,7 +4671,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4677,7 +4682,7 @@ int main(int argc, char ** argv) {
|
||||
// optionally return "fail_on_no_slot" error
|
||||
if (req.has_param("fail_on_no_slot")) {
|
||||
if (res_task->n_idle_slots == 0) {
|
||||
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -4687,7 +4692,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_metrics) {
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4705,7 +4710,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4790,7 +4795,7 @@ int main(int argc, char ** argv) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -4811,7 +4816,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4822,7 +4827,7 @@ int main(int argc, char ** argv) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -4843,7 +4848,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4866,7 +4871,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4876,7 +4881,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
if (params.slot_save_path.empty()) {
|
||||
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4886,7 +4891,7 @@ int main(int argc, char ** argv) {
|
||||
try {
|
||||
id_slot = std::stoi(id_slot_str);
|
||||
} catch (const std::exception &) {
|
||||
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4899,7 +4904,7 @@ int main(int argc, char ** argv) {
|
||||
} else if (action == "erase") {
|
||||
handle_slots_erase(req, res, id_slot);
|
||||
} else {
|
||||
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4947,7 +4952,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!ctx_server.params_base.endpoint_props) {
|
||||
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5044,7 +5049,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
rd->post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5056,7 +5061,7 @@ int main(int argc, char ** argv) {
|
||||
if (all_results.is_terminated) {
|
||||
return; // connection is closed
|
||||
} else if (all_results.error) {
|
||||
res_error(res, all_results.error->to_json());
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
} else {
|
||||
json arr = json::array();
|
||||
@@ -5076,7 +5081,7 @@ int main(int argc, char ** argv) {
|
||||
if (first_result == nullptr) {
|
||||
return; // connection is closed
|
||||
} else if (first_result->is_error()) {
|
||||
res_error(res, first_result->to_json());
|
||||
res_err(res, first_result->to_json());
|
||||
return;
|
||||
} else {
|
||||
GGML_ASSERT(
|
||||
@@ -5183,7 +5188,7 @@ int main(int argc, char ** argv) {
|
||||
err += "middle token is missing. ";
|
||||
}
|
||||
if (!err.empty()) {
|
||||
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5192,20 +5197,20 @@ int main(int argc, char ** argv) {
|
||||
// validate input
|
||||
if (data.contains("prompt") && !data.at("prompt").is_string()) {
|
||||
// prompt is optional
|
||||
res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (!data.contains("input_prefix")) {
|
||||
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (!data.contains("input_suffix")) {
|
||||
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
||||
// input_extra is optional
|
||||
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5213,12 +5218,12 @@ int main(int argc, char ** argv) {
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -5380,12 +5385,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
|
||||
if (!ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5399,7 +5404,7 @@ int main(int argc, char ** argv) {
|
||||
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
|
||||
prompt = body.at("content");
|
||||
} else {
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5409,7 +5414,7 @@ int main(int argc, char ** argv) {
|
||||
if (format == "base64") {
|
||||
use_base64 = true;
|
||||
} else if (format != "float") {
|
||||
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -5418,7 +5423,7 @@ int main(int argc, char ** argv) {
|
||||
for (const auto & tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -5459,7 +5464,7 @@ int main(int argc, char ** argv) {
|
||||
if (all_results.is_terminated) {
|
||||
return; // connection is closed
|
||||
} else if (all_results.error) {
|
||||
res_error(res, all_results.error->to_json());
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
} else {
|
||||
for (auto & res : all_results.results) {
|
||||
@@ -5485,7 +5490,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
|
||||
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5500,18 +5505,18 @@ int main(int argc, char ** argv) {
|
||||
if (body.count("query") == 1) {
|
||||
query = body.at("query");
|
||||
if (!query.is_string()) {
|
||||
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> documents = json_value(body, "documents",
|
||||
json_value(body, "texts", std::vector<std::string>()));
|
||||
if (documents.empty()) {
|
||||
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5541,7 +5546,7 @@ int main(int argc, char ** argv) {
|
||||
if (all_results.is_terminated) {
|
||||
return; // connection is closed
|
||||
} else if (all_results.error) {
|
||||
res_error(res, all_results.error->to_json());
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
} else {
|
||||
for (auto & res : all_results.results) {
|
||||
@@ -5594,7 +5599,7 @@ int main(int argc, char ** argv) {
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
const json body = json::parse(req.body);
|
||||
if (!body.is_array()) {
|
||||
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5612,7 +5617,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
+19
-29
@@ -1,6 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { RemoveButton } from '$lib/components/app';
|
||||
import { formatFileSize, getFileTypeLabel, getPreviewText } from '$lib/utils/file-preview';
|
||||
import { FileTypeCategory, MimeTypeText } from '$lib/enums/files';
|
||||
|
||||
@@ -66,17 +65,15 @@
|
||||
</button>
|
||||
{:else}
|
||||
<!-- Non-readonly mode (ChatForm) -->
|
||||
<div class="relative rounded-lg border border-border bg-muted p-3 {className} w-64">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="absolute top-2 right-2 h-6 w-6 bg-white/20 p-0 hover:bg-white/30"
|
||||
onclick={() => onRemove?.(id)}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
<button
|
||||
class="group relative rounded-lg border border-border bg-muted p-3 {className} {textContent
|
||||
? 'max-h-24 max-w-72'
|
||||
: 'max-w-36'} cursor-pointer text-left"
|
||||
onclick={onClick}
|
||||
>
|
||||
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<RemoveButton {id} {onRemove} />
|
||||
</div>
|
||||
|
||||
<div class="pr-8">
|
||||
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
|
||||
@@ -85,7 +82,7 @@
|
||||
<div class="relative">
|
||||
<div
|
||||
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
|
||||
style="max-height: 3.6em; line-height: 1.2em;"
|
||||
style="max-height: 3rem; line-height: 1.2em;"
|
||||
>
|
||||
{getPreviewText(textContent)}
|
||||
</div>
|
||||
@@ -98,11 +95,11 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
{/if}
|
||||
{:else}
|
||||
<button
|
||||
class="flex items-center gap-2 gap-3 rounded-lg border border-border bg-muted p-3 {className}"
|
||||
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
|
||||
onclick={onClick}
|
||||
>
|
||||
<div
|
||||
@@ -112,7 +109,9 @@
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<span class="max-w-36 truncate text-sm font-medium text-foreground md:max-w-72">
|
||||
<span
|
||||
class="max-w-24 truncate text-sm font-medium text-foreground group-hover:pr-6 md:max-w-32"
|
||||
>
|
||||
{name}
|
||||
</span>
|
||||
|
||||
@@ -122,18 +121,9 @@
|
||||
</div>
|
||||
|
||||
{#if !readonly}
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 p-0"
|
||||
onclick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<RemoveButton {id} {onRemove} />
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
+5
-14
@@ -1,6 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { RemoveButton } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
@@ -26,12 +25,12 @@
|
||||
class: className = '',
|
||||
// Default to small size for form previews
|
||||
width = 'w-auto',
|
||||
height = 'h-24',
|
||||
height = 'h-16',
|
||||
imageClass = ''
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative overflow-hidden rounded-lg border border-border bg-muted {className}">
|
||||
<div class="group relative overflow-hidden rounded-lg border border-border bg-muted {className}">
|
||||
{#if onClick}
|
||||
<button
|
||||
type="button"
|
||||
@@ -55,17 +54,9 @@
|
||||
|
||||
{#if !readonly}
|
||||
<div
|
||||
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity hover:opacity-100"
|
||||
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 text-white hover:bg-white/30"
|
||||
onclick={() => onRemove?.(id)}
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
<RemoveButton {id} {onRemove} class="text-white" />
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -153,7 +153,7 @@
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden !p-10 sm:w-auto sm:max-w-6xl">
|
||||
<Dialog.Header class="flex-shrink-0">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<div class="flex items-center gap-3">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="h-5 w-5 text-muted-foreground" />
|
||||
|
||||
+142
-59
@@ -1,11 +1,16 @@
|
||||
<script lang="ts">
|
||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
||||
import ChatAttachmentsViewAllDialog from './ChatAttachmentsViewAllDialog.svelte';
|
||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
style?: string;
|
||||
// For ChatMessage - stored attachments
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
@@ -16,10 +21,13 @@
|
||||
imageClass?: string;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
// Limit display to single row with "+ X more" button
|
||||
limitToSingleRow?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
style = '',
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
@@ -27,36 +35,23 @@
|
||||
// Default to small size for form previews
|
||||
imageClass = '',
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto'
|
||||
imageWidth = 'w-auto',
|
||||
limitToSingleRow = false
|
||||
}: Props = $props();
|
||||
|
||||
let displayItems = $derived(getDisplayItems());
|
||||
|
||||
// Preview dialog state
|
||||
let canScrollLeft = $state(false);
|
||||
let canScrollRight = $state(false);
|
||||
let isScrollable = $state(false);
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewItem = $state<{
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
} | null>(null);
|
||||
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
|
||||
let scrollContainer: HTMLDivElement | undefined = $state();
|
||||
let showViewAll = $derived(limitToSingleRow && displayItems.length > 0 && isScrollable);
|
||||
let viewAllDialogOpen = $state(false);
|
||||
|
||||
function getDisplayItems() {
|
||||
const items: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
size?: number;
|
||||
preview?: string;
|
||||
type: string;
|
||||
isImage: boolean;
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
attachmentIndex?: number;
|
||||
textContent?: string;
|
||||
}> = [];
|
||||
function getDisplayItems(): ChatAttachmentDisplayItem[] {
|
||||
const items: ChatAttachmentDisplayItem[] = [];
|
||||
|
||||
// Add uploaded files (ChatForm)
|
||||
for (const file of uploadedFiles) {
|
||||
@@ -127,14 +122,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
return items;
|
||||
return items.reverse();
|
||||
}
|
||||
|
||||
function openPreview(item: (typeof displayItems)[0], event?: Event) {
|
||||
if (event) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
}
|
||||
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
previewItem = {
|
||||
uploadedFile: item.uploadedFile,
|
||||
@@ -147,38 +140,118 @@
|
||||
};
|
||||
previewDialogOpen = true;
|
||||
}
|
||||
|
||||
function scrollLeft(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * -0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function scrollRight(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * 0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function updateScrollButtons() {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
|
||||
|
||||
canScrollLeft = scrollLeft > 0;
|
||||
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1;
|
||||
isScrollable = scrollWidth > clientWidth;
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (scrollContainer && displayItems.length) {
|
||||
scrollContainer.scrollLeft = 0;
|
||||
|
||||
setTimeout(() => {
|
||||
updateScrollButtons();
|
||||
}, 0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if displayItems.length > 0}
|
||||
<div class="flex flex-wrap items-start {readonly ? 'justify-end' : ''} gap-3 {className}">
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isImage && item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentFilePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
<div class={className} {style}>
|
||||
<div class="relative">
|
||||
<button
|
||||
class="absolute top-1/2 left-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollLeft
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollLeft}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<ChevronLeft class="h-4 w-4" />
|
||||
</button>
|
||||
|
||||
<div
|
||||
class="scrollbar-hide flex items-start gap-3 overflow-x-auto"
|
||||
bind:this={scrollContainer}
|
||||
onscroll={updateScrollButtons}
|
||||
>
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isImage && item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentFilePreview
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="absolute top-1/2 right-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollRight
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollRight}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<ChevronRight class="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{#if showViewAll}
|
||||
<div class="mt-2 -mr-2 flex justify-end px-4">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 text-xs text-muted-foreground hover:text-foreground"
|
||||
onclick={() => (viewAllDialogOpen = true)}
|
||||
>
|
||||
View all
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -194,3 +267,13 @@
|
||||
textContent={previewItem.textContent}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<ChatAttachmentsViewAllDialog
|
||||
bind:open={viewAllDialogOpen}
|
||||
{uploadedFiles}
|
||||
{attachments}
|
||||
{readonly}
|
||||
{onFileRemove}
|
||||
imageHeight="h-64"
|
||||
{imageClass}
|
||||
/>
|
||||
|
||||
+203
@@ -0,0 +1,203 @@
|
||||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||
|
||||
interface Props {
|
||||
open?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
imageClass?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(false),
|
||||
uploadedFiles = [],
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto',
|
||||
imageClass = ''
|
||||
}: Props = $props();
|
||||
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
|
||||
|
||||
let displayItems = $derived(getDisplayItems());
|
||||
let imageItems = $derived(displayItems.filter((item) => item.isImage));
|
||||
let fileItems = $derived(displayItems.filter((item) => !item.isImage));
|
||||
|
||||
function getDisplayItems(): ChatAttachmentDisplayItem[] {
|
||||
const items: ChatAttachmentDisplayItem[] = [];
|
||||
|
||||
for (const file of uploadedFiles) {
|
||||
items.push({
|
||||
id: file.id,
|
||||
name: file.name,
|
||||
size: file.size,
|
||||
preview: file.preview,
|
||||
type: file.type,
|
||||
isImage: getFileTypeCategory(file.type) === FileTypeCategory.IMAGE,
|
||||
uploadedFile: file,
|
||||
textContent: file.textContent
|
||||
});
|
||||
}
|
||||
|
||||
for (const [index, attachment] of attachments.entries()) {
|
||||
if (attachment.type === 'imageFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
preview: attachment.base64Url,
|
||||
type: 'image',
|
||||
isImage: true,
|
||||
attachment,
|
||||
attachmentIndex: index
|
||||
});
|
||||
} else if (attachment.type === 'textFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'text',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
} else if (attachment.type === 'context') {
|
||||
// Legacy format from old webui - treat as text file
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'text',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
} else if (attachment.type === 'audioFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: attachment.mimeType || 'audio',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index
|
||||
});
|
||||
} else if (attachment.type === 'pdfFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'application/pdf',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return items.reverse();
|
||||
}
|
||||
|
||||
function openPreview(item: (typeof displayItems)[0], event?: Event) {
|
||||
if (event) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
}
|
||||
|
||||
previewItem = {
|
||||
uploadedFile: item.uploadedFile,
|
||||
attachment: item.attachment,
|
||||
preview: item.preview,
|
||||
name: item.name,
|
||||
type: item.type,
|
||||
size: item.size,
|
||||
textContent: item.textContent
|
||||
};
|
||||
previewDialogOpen = true;
|
||||
}
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Portal>
|
||||
<Dialog.Overlay />
|
||||
|
||||
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>All Attachments ({displayItems.length})</Dialog.Title>
|
||||
<Dialog.Description class="text-sm text-muted-foreground">
|
||||
View and manage all attached files
|
||||
</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
|
||||
{#if fileItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each fileItems as item (item.id)}
|
||||
<ChatAttachmentFilePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if imageItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each imageItems as item (item.id)}
|
||||
{#if item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
</Dialog.Root>
|
||||
|
||||
{#if previewItem}
|
||||
<ChatAttachmentPreviewDialog
|
||||
bind:open={previewDialogOpen}
|
||||
uploadedFile={previewItem.uploadedFile}
|
||||
attachment={previewItem.attachment}
|
||||
preview={previewItem.preview}
|
||||
name={previewItem.name}
|
||||
type={previewItem.type}
|
||||
size={previewItem.size}
|
||||
textContent={previewItem.textContent}
|
||||
/>
|
||||
{/if}
|
||||
@@ -232,7 +232,13 @@
|
||||
onsubmit={handleSubmit}
|
||||
class="{INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {className}"
|
||||
>
|
||||
<ChatAttachmentsList bind:uploadedFiles {onFileRemove} class="mb-3 px-5 pt-5" />
|
||||
<ChatAttachmentsList
|
||||
bind:uploadedFiles
|
||||
{onFileRemove}
|
||||
limitToSingleRow
|
||||
class="py-5"
|
||||
style="scroll-padding: 1rem;"
|
||||
/>
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl px-5 py-3 shadow-sm transition-all focus-within:shadow-md"
|
||||
|
||||
@@ -333,7 +333,7 @@
|
||||
ondrop={handleDrop}
|
||||
role="main"
|
||||
>
|
||||
<div class="w-full max-w-2xl px-4">
|
||||
<div class="w-full max-w-[48rem] px-4">
|
||||
<div class="mb-8 text-center" in:fade={{ duration: 300 }}>
|
||||
<h1 class="mb-2 text-3xl font-semibold tracking-tight">llama.cpp</h1>
|
||||
|
||||
@@ -368,7 +368,7 @@
|
||||
<AlertDialog.Portal>
|
||||
<AlertDialog.Overlay />
|
||||
|
||||
<AlertDialog.Content class="max-w-md">
|
||||
<AlertDialog.Content class="flex max-w-md flex-col">
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title>File Upload Error</AlertDialog.Title>
|
||||
|
||||
@@ -377,7 +377,7 @@
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="!max-h-[50vh] min-h-0 flex-1 space-y-4 overflow-y-auto">
|
||||
{#if fileErrorData.generallyUnsupported.length > 0}
|
||||
<div class="space-y-2">
|
||||
<h4 class="text-sm font-medium text-destructive">Unsupported File Types</h4>
|
||||
@@ -398,8 +398,6 @@
|
||||
|
||||
{#if fileErrorData.modalityUnsupported.length > 0}
|
||||
<div class="space-y-2">
|
||||
<h4 class="text-sm font-medium text-destructive">Model Compatibility Issues</h4>
|
||||
|
||||
<div class="space-y-1">
|
||||
{#each fileErrorData.modalityUnsupported as file (file.name)}
|
||||
<div class="rounded-md bg-destructive/10 px-3 py-2">
|
||||
@@ -415,14 +413,14 @@
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="rounded-md bg-muted/50 p-3">
|
||||
<h4 class="mb-2 text-sm font-medium">This model supports:</h4>
|
||||
<div class="rounded-md bg-muted/50 p-3">
|
||||
<h4 class="mb-2 text-sm font-medium">This model supports:</h4>
|
||||
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{fileErrorData.supportedTypes.join(', ')}
|
||||
</p>
|
||||
</div>
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{fileErrorData.supportedTypes.join(', ')}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<AlertDialog.Footer>
|
||||
|
||||
@@ -2,6 +2,7 @@ export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttac
|
||||
export { default as ChatAttachmentFilePreview } from './chat/ChatAttachments/ChatAttachmentFilePreview.svelte';
|
||||
export { default as ChatAttachmentImagePreview } from './chat/ChatAttachments/ChatAttachmentImagePreview.svelte';
|
||||
export { default as ChatAttachmentPreviewDialog } from './chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte';
|
||||
export { default as ChatAttachmentsViewAllDialog } from './chat/ChatAttachments/ChatAttachmentsViewAllDialog.svelte';
|
||||
|
||||
export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte';
|
||||
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';
|
||||
@@ -42,6 +43,8 @@ export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.sve
|
||||
|
||||
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
|
||||
|
||||
export { default as RemoveButton } from './misc/RemoveButton.svelte';
|
||||
|
||||
export { default as ServerStatus } from './server/ServerStatus.svelte';
|
||||
export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte';
|
||||
export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte';
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
onRemove?: (id: string) => void;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { id, onRemove, class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}"
|
||||
onclick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
+23
@@ -11,6 +11,29 @@ export interface ChatUploadedFile {
|
||||
textContent?: string;
|
||||
}
|
||||
|
||||
export interface ChatAttachmentDisplayItem {
|
||||
id: string;
|
||||
name: string;
|
||||
size?: number;
|
||||
preview?: string;
|
||||
type: string;
|
||||
isImage: boolean;
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
attachmentIndex?: number;
|
||||
textContent?: string;
|
||||
}
|
||||
|
||||
export interface ChatAttachmentPreviewItem {
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
}
|
||||
|
||||
export interface ChatMessageSiblingInfo {
|
||||
message: DatabaseMessage;
|
||||
siblingIds: string[];
|
||||
|
||||
Reference in New Issue
Block a user