mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 17:47:40 +02:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cb623de3fc | |||
| 7aaeedc098 | |||
| 3347e6d904 | |||
| 1a139644a8 | |||
| 2376b7758c | |||
| dbed61294a | |||
| 80deff3648 | |||
| 8b1c339bd2 | |||
| 416e7c7f47 | |||
| 5b2093becc | |||
| 52e5d421f1 | |||
| 4db5641210 | |||
| 72bd7321a7 | |||
| 22e1ce2f81 | |||
| 1411d9275a | |||
| 662192e1dc | |||
| 24dc769f1b | |||
| 4dca015b7e | |||
| 9a8860cf5d | |||
| 9d3ef4809f | |||
| c7b7db0445 | |||
| 1568d13c2c | |||
| 439342ea0b | |||
| 234ae7d7bd | |||
| 38eaf32af1 | |||
| 9b17d74ab7 | |||
| e1fcf8b09b | |||
| 6cd0cf72ce | |||
| d396b43748 | |||
| 45c6ef7307 | |||
| 2606b0adab | |||
| 307772fcda | |||
| f1bad23f88 | |||
| becc4816dd | |||
| c4abcb2457 | |||
| 389ac78b26 |
@@ -9,7 +9,7 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
|
||||
- **Size**: ~200k+ lines of code across 1000+ files
|
||||
- **Architecture**: Modular design with main library (`libllama`) and 40+ executable tools/examples
|
||||
- **Core dependency**: ggml tensor library (vendored in `ggml/` directory)
|
||||
- **Backends supported**: CPU (AVX/NEON optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
|
||||
- **License**: MIT
|
||||
|
||||
## Build Instructions
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
name: CI (AMD)
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/build-amd.yml',
|
||||
'**/CMakeLists.txt',
|
||||
'**/.cmake',
|
||||
'**/*.h',
|
||||
'**/*.hpp',
|
||||
'**/*.c',
|
||||
'**/*.cpp',
|
||||
'**/*.cu',
|
||||
'**/*.cuh',
|
||||
'**/*.comp'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
@@ -1599,6 +1599,34 @@ jobs:
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ range of hardware - locally and in the cloud.
|
||||
- Plain C/C++ implementation without any dependencies
|
||||
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
|
||||
- AVX, AVX2, AVX512 and AMX support for x86 architectures
|
||||
- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
|
||||
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
|
||||
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
|
||||
- Vulkan and SYCL backend support
|
||||
|
||||
+1
-5
@@ -355,11 +355,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}, NULL);
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
|
||||
@@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
||||
log->set_timestamps(timestamps);
|
||||
}
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
// the common_log uses an internal worker thread to print/write log messages
|
||||
// when the worker thread is paused, incoming log messages are discarded
|
||||
struct common_log;
|
||||
|
||||
+81
-43
@@ -189,10 +189,10 @@ class ModelBase:
|
||||
return tensors
|
||||
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
|
||||
is_safetensors: bool = len(part_names) > 0
|
||||
if not is_safetensors:
|
||||
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
|
||||
|
||||
tensor_names_from_index: set[str] = set()
|
||||
|
||||
@@ -209,6 +209,7 @@ class ModelBase:
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
part_names |= set(weight_map.values())
|
||||
else:
|
||||
weight_map = {}
|
||||
else:
|
||||
@@ -825,6 +826,15 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_expert_group_used_count(n_group_used)
|
||||
logger.info(f"gguf: expert groups used count = {n_group_used}")
|
||||
|
||||
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
|
||||
if score_func == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif score_func == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
|
||||
logger.info(f"gguf: expert score gating function = {score_func}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
self.gguf_writer.add_value_length(head_dim)
|
||||
@@ -1124,6 +1134,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
|
||||
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
|
||||
res = "afmoe"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
|
||||
res = "bailingmoe2"
|
||||
@@ -2533,6 +2546,72 @@ class ArceeModel(LlamaModel):
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@ModelBase.register("AfmoeForCausalLM")
|
||||
class AfmoeModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.AFMOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# MoE parameters
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
|
||||
self.gguf_writer.add_expert_shared_count(n_shared_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
|
||||
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
|
||||
|
||||
# Route normalization and scaling
|
||||
if (route_norm := self.hparams.get("route_norm")) is not None:
|
||||
self.gguf_writer.add_expert_weights_norm(route_norm)
|
||||
if (route_scale := self.hparams.get("route_scale")) is not None:
|
||||
self.gguf_writer.add_expert_weights_scale(route_scale)
|
||||
|
||||
# Sliding window attention
|
||||
if (sliding_window := self.hparams.get("sliding_window")) is not None:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Handle expert weights - they're already merged in the HF format
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["gate_proj", "up_proj", "down_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename_to_retrieve])
|
||||
del self._experts[bid][ename_to_retrieve]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"LlavaForConditionalGeneration", # pixtral
|
||||
"Mistral3ForConditionalGeneration", # mistral small 3.1
|
||||
@@ -7104,13 +7183,6 @@ class DeepseekV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["scoring_func"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["scoring_func"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
@@ -7216,12 +7288,6 @@ class MiniMaxM2Model(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hparams["scoring_func"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif self.hparams["scoring_func"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
|
||||
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
|
||||
@@ -7314,11 +7380,6 @@ class Dots1Model(Qwen2MoeModel):
|
||||
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
|
||||
|
||||
if self.hparams["scoring_func"] == "noaux_tc":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.endswith("e_score_correction_bias"):
|
||||
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||
@@ -7779,12 +7840,6 @@ class Glm4MoeModel(TextModel):
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
|
||||
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
|
||||
|
||||
# Patch broken chat template
|
||||
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
|
||||
special_vocab.chat_template = special_vocab.chat_template.replace(
|
||||
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
|
||||
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
|
||||
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
@@ -8639,13 +8694,6 @@ class BailingMoeV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["score_function"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
|
||||
|
||||
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
|
||||
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
|
||||
|
||||
@@ -9341,16 +9389,6 @@ class HunYuanModel(TextModel):
|
||||
class SmolLM3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.SMOLLM3
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# remove unsupported array slicing in chat template
|
||||
# ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
if tokenizer.chat_template is not None:
|
||||
chat_template = tokenizer.chat_template.replace("[:]", "")
|
||||
self.gguf_writer.add_chat_template(chat_template)
|
||||
|
||||
|
||||
@ModelBase.register("GptOssForCausalLM")
|
||||
class GptOssModel(TextModel):
|
||||
|
||||
@@ -139,6 +139,7 @@ models = [
|
||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||
|
||||
+49
-44
@@ -14,103 +14,108 @@ Legend:
|
||||
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
+16067
-5133
File diff suppressed because it is too large
Load Diff
+16224
-6894
File diff suppressed because it is too large
Load Diff
+2348
-149
File diff suppressed because it is too large
Load Diff
+14536
-4360
File diff suppressed because it is too large
Load Diff
@@ -475,6 +475,7 @@ extern "C" {
|
||||
GGML_OP_COS,
|
||||
GGML_OP_SUM,
|
||||
GGML_OP_SUM_ROWS,
|
||||
GGML_OP_CUMSUM,
|
||||
GGML_OP_MEAN,
|
||||
GGML_OP_ARGMAX,
|
||||
GGML_OP_COUNT_EQUAL,
|
||||
@@ -530,6 +531,8 @@ extern "C" {
|
||||
GGML_OP_TIMESTEP_EMBEDDING,
|
||||
GGML_OP_ARGSORT,
|
||||
GGML_OP_LEAKY_RELU,
|
||||
GGML_OP_TRI,
|
||||
GGML_OP_FILL,
|
||||
|
||||
GGML_OP_FLASH_ATTN_EXT,
|
||||
GGML_OP_FLASH_ATTN_BACK,
|
||||
@@ -542,6 +545,7 @@ extern "C" {
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
GGML_OP_RWKV_WKV7,
|
||||
GGML_OP_SOLVE_TRI,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@@ -576,6 +580,8 @@ extern "C" {
|
||||
GGML_UNARY_OP_HARDSWISH,
|
||||
GGML_UNARY_OP_HARDSIGMOID,
|
||||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_EXPM1,
|
||||
GGML_UNARY_OP_SOFTPLUS,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_XIELU,
|
||||
GGML_UNARY_OP_FLOOR,
|
||||
@@ -620,6 +626,13 @@ extern "C" {
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
};
|
||||
|
||||
enum ggml_tri_type {
|
||||
GGML_TRI_TYPE_UPPER_DIAG = 0,
|
||||
GGML_TRI_TYPE_UPPER = 1,
|
||||
GGML_TRI_TYPE_LOWER_DIAG = 2,
|
||||
GGML_TRI_TYPE_LOWER = 3
|
||||
};
|
||||
|
||||
struct ggml_init_params {
|
||||
// memory pool
|
||||
size_t mem_size; // bytes
|
||||
@@ -957,6 +970,22 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_expm1_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_softplus_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_sin(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
@@ -983,6 +1012,10 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cumsum(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// mean along rows
|
||||
GGML_API struct ggml_tensor * ggml_mean(
|
||||
struct ggml_context * ctx,
|
||||
@@ -2187,6 +2220,23 @@ extern "C" {
|
||||
int shift2,
|
||||
int shift3);
|
||||
|
||||
// Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
|
||||
// zeroes everywhere outside the masked area
|
||||
GGML_API struct ggml_tensor * ggml_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_tri_type type);
|
||||
|
||||
// Fill tensor a with constant c
|
||||
GGML_API struct ggml_tensor * ggml_fill(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_fill_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c);
|
||||
|
||||
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||||
// timesteps: [N,]
|
||||
@@ -2356,6 +2406,27 @@ extern "C" {
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
|
||||
* without zeroes on the diagonal (i.e. invertible).
|
||||
* B can have any number of columns, but must have the same number of rows as A
|
||||
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
|
||||
* Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
|
||||
* where n > 100 sparingly, pre-chunk if necessary.
|
||||
*
|
||||
* If left = false, solves xA=B instead
|
||||
* If lower = false, assumes upper triangular instead
|
||||
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
|
||||
*
|
||||
* TODO: currently only lower, right, non-unitriangular variant is implemented
|
||||
*/
|
||||
GGML_API struct ggml_tensor * ggml_solve_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool left,
|
||||
bool lower,
|
||||
bool uni);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -48,15 +48,14 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
|
||||
default:
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
|
||||
// added.
|
||||
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
|
||||
@@ -87,10 +86,20 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
std::reverse(acl_ne, acl_ne + final_dims);
|
||||
std::reverse(acl_stride, acl_stride + final_dims);
|
||||
|
||||
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1, tensor->data);
|
||||
aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,
|
||||
format, &acl_storage_len, 1, tensor->data);
|
||||
|
||||
return acl_tensor;
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {
|
||||
aclIntArray * raw = aclCreateIntArray(value, size);
|
||||
return acl_int_array_ptr(raw);
|
||||
}
|
||||
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {
|
||||
aclScalar * raw = aclCreateScalar(value, dataType);
|
||||
return acl_scalar_ptr(raw);
|
||||
}
|
||||
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
|
||||
|
||||
@@ -23,11 +23,12 @@
|
||||
#ifndef CANN_ACL_TENSOR_H
|
||||
#define CANN_ACL_TENSOR_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include "common.h"
|
||||
|
||||
#include <aclnn/aclnn_base.h>
|
||||
#include "common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
/**
|
||||
* @brief Maps a ggml_type to its corresponding aclDataType.
|
||||
@@ -43,6 +44,20 @@
|
||||
*/
|
||||
aclDataType ggml_cann_type_mapping(ggml_type type);
|
||||
|
||||
// Deleter for acl objects.
|
||||
template <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {
|
||||
void operator()(T * ptr) const noexcept {
|
||||
if (ptr) {
|
||||
ACL_CHECK(DestroyFunc(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using acl_tensor_ptr = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;
|
||||
using acl_int_array_ptr = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;
|
||||
using acl_scalar_ptr = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;
|
||||
using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;
|
||||
|
||||
/**
|
||||
* @brief Creates an ACL tensor from a ggml_tensor with optional shape.
|
||||
*
|
||||
@@ -62,12 +77,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
|
||||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
|
||||
/**
|
||||
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
|
||||
@@ -90,14 +105,14 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
template <typename TYPE>
|
||||
aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
int64_t tmp_ne[GGML_MAX_DIMS * 2];
|
||||
int64_t tmp_stride[GGML_MAX_DIMS * 2];
|
||||
|
||||
@@ -114,10 +129,75 @@ aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
std::reverse(tmp_ne, tmp_ne + dims);
|
||||
std::reverse(tmp_stride, tmp_stride + dims);
|
||||
|
||||
aclTensor * acl_tensor =
|
||||
aclTensor * raw =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
|
||||
|
||||
return acl_tensor;
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create an ACL int array resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclIntArray from the provided int64_t values
|
||||
* and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom
|
||||
* deleter). The returned pointer owns the ACL resource and will automatically
|
||||
* destroy it via aclDestroyIntArray().
|
||||
*
|
||||
* @param value Pointer to the int64_t elements.
|
||||
* @param size Number of elements in value.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL int array.
|
||||
*/
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL scalar resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclScalar from the raw value pointer and ACL
|
||||
* data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with
|
||||
* a custom deleter). The returned pointer owns the ACL scalar and will
|
||||
* automatically destroy it via aclDestroyScalar().
|
||||
*
|
||||
* @param value Pointer to the raw scalar memory.
|
||||
* @param dataType ACL data type of the scalar.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL scalar.
|
||||
*/
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL tensor list from multiple tensor smart pointers.
|
||||
*
|
||||
* This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with
|
||||
* custom deleter) and produces an aclTensorList using aclCreateTensorList().
|
||||
*
|
||||
* The lifecycle management of the tensor objects changes as follows:
|
||||
* - aclCreateTensorList() takes ownership of the tensors
|
||||
* - Each input smart pointer releases ownership using release()
|
||||
* - As a result, the tensors will NOT be destroyed by unique_ptr
|
||||
* - Instead, they will be destroyed when aclDestroyTensorList() is called
|
||||
*
|
||||
* This ensures correct ownership transfer and prevents double-free situations.
|
||||
*
|
||||
* @param acl_tensor_ptr Variadic template parameter; each argument must be
|
||||
* a unique_ptr-like type supporting get() and release().
|
||||
*
|
||||
* @param tensors Variadic list of acl_tensor_ptr objects. Ownership of
|
||||
* each tensor is transferred away from these smart pointers.
|
||||
*
|
||||
* @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.
|
||||
*
|
||||
* @note This implementation is C++11 compatible. The ownership-release process is
|
||||
* executed using a pack expansion inside an initializer list.
|
||||
*/
|
||||
template <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {
|
||||
aclTensor * raw_tensors[] = { tensors.get()... };
|
||||
aclTensorList * raw = aclCreateTensorList(raw_tensors, sizeof...(tensors));
|
||||
// aclTensor will release by aclTensorList, so release ownership without
|
||||
// destroying the tensor
|
||||
int dummy[] = { (tensors.release(), 0)... };
|
||||
GGML_UNUSED(dummy);
|
||||
return acl_tensor_list_ptr(raw);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
+587
-691
File diff suppressed because it is too large
Load Diff
+42
-197
@@ -23,33 +23,35 @@
|
||||
#ifndef CANN_ACLNN_OPS
|
||||
#define CANN_ACLNN_OPS
|
||||
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <aclnnop/aclnn_abs.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_arange.h>
|
||||
#include <aclnnop/aclnn_argsort.h>
|
||||
#include <aclnnop/aclnn_cat.h>
|
||||
#include <aclnnop/aclnn_clamp.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_gelu.h>
|
||||
#include <aclnnop/aclnn_gelu_v2.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_hardsigmoid.h>
|
||||
#include <aclnnop/aclnn_hardswish.h>
|
||||
#include <aclnnop/aclnn_leaky_relu.h>
|
||||
#include <aclnnop/aclnn_relu.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_log.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_norm.h>
|
||||
#include <aclnnop/aclnn_logsoftmax.h>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_norm.h>
|
||||
#include <aclnnop/aclnn_relu.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_set>
|
||||
|
||||
/**
|
||||
* @brief Repeats a ggml tensor along each dimension to match the dimensions
|
||||
@@ -688,12 +690,12 @@ void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
|
||||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
aclTensor ** acl_src0,
|
||||
aclTensor ** acl_src1,
|
||||
aclTensor ** acl_dst);
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
acl_tensor_ptr & acl_src0,
|
||||
acl_tensor_ptr & acl_src1,
|
||||
acl_tensor_ptr & acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
|
||||
@@ -873,83 +875,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
|
||||
(vec.emplace_back(make_acl_resource(args)), ...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Task class that wraps the execution of an aclnn function call.
|
||||
*/
|
||||
class aclnn_task : public cann_task {
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func,
|
||||
void * workspace_addr,
|
||||
uint64_t workspace_size,
|
||||
aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
uint64_t workspace_size_;
|
||||
aclOpExecutor * executor_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class that releases ACL resources after usage.
|
||||
*/
|
||||
class release_resource_task : public cann_task {
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
|
||||
|
||||
virtual void run_task() override { resource_.clear(); }
|
||||
private:
|
||||
std::vector<any_acl_resource> resource_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory copy operations.
|
||||
*/
|
||||
class async_memcpy_task : public cann_task {
|
||||
public:
|
||||
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
|
||||
dst_(dst),
|
||||
src_(src),
|
||||
size_(size),
|
||||
kind_(kind),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
|
||||
private:
|
||||
void * dst_;
|
||||
const void * src_;
|
||||
size_t size_;
|
||||
aclrtMemcpyKind kind_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory set operations.
|
||||
*/
|
||||
class async_memset_task : public cann_task {
|
||||
public:
|
||||
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
|
||||
buffer_(buffer),
|
||||
size_(size),
|
||||
value_(value),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
|
||||
private:
|
||||
void * buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Launches an asynchronous task using the memory allocator.
|
||||
*
|
||||
@@ -968,95 +893,20 @@ class async_memset_task : public cann_task {
|
||||
* same stream are executed in queue order.
|
||||
*/
|
||||
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
if (CTX.async_mode) { \
|
||||
auto task = \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
|
||||
CTX.task_queue.submit_task(std::move(task)); \
|
||||
} else { \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} \
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Registers and releases multiple ACL resources, optionally deferring the release
|
||||
* using a task.
|
||||
*
|
||||
* @tparam Args Types of the ACL resources.
|
||||
* @param ctx Backend context which manages task submission and async mode.
|
||||
* @param args Pointers to ACL resources to be released.
|
||||
*/
|
||||
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
std::vector<any_acl_resource> resources;
|
||||
register_acl_resources(resources, std::forward<Args>(args)...);
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<release_resource_task>(std::move(resources));
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param dst Destination memory address.
|
||||
* @param src Source memory address.
|
||||
* @param len Size of memory to copy (in bytes).
|
||||
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
|
||||
*/
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx->async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
|
||||
ctx->task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param buffer Memory buffer to be set.
|
||||
* @param size Size of the memory buffer (in bytes).
|
||||
* @param value Value to set in the buffer.
|
||||
*/
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
|
||||
*
|
||||
@@ -1129,15 +979,11 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
aclTensor * acl_src0;
|
||||
aclTensor * acl_src1;
|
||||
aclTensor * acl_dst;
|
||||
acl_tensor_ptr acl_src0, acl_src1, acl_dst;
|
||||
|
||||
// Need bcast
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
binary_op(ctx, acl_src0, acl_src1, acl_dst);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);
|
||||
binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1147,7 +993,7 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
* and stores the result in the destination tensor.
|
||||
*
|
||||
* @tparam unary_op A callable with the signature:
|
||||
* void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
|
||||
* void(ggml_backend_cann_context&, aclTensor *, aclTensor *)
|
||||
* where the first aclTensor is the source and the second is the destination.
|
||||
* @param ctx The CANN backend context for managing resources and execution.
|
||||
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
|
||||
@@ -1156,11 +1002,10 @@ template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
aclTensor * acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
unary_op(ctx, acl_src.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
+19
-150
@@ -23,26 +23,26 @@
|
||||
#ifndef CANN_COMMON_H
|
||||
#define CANN_COMMON_H
|
||||
|
||||
#include <acl/acl.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <list>
|
||||
|
||||
#include "../ggml-impl.h"
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
#include "../ggml-impl.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#define MATRIX_ROW_PADDING 512
|
||||
#define GGML_CANN_MAX_STREAMS 8
|
||||
@@ -214,130 +214,6 @@ struct ggml_cann_pool_alloc {
|
||||
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Function pointer type for ACLNN operator calls.
|
||||
*/
|
||||
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
|
||||
|
||||
/**
|
||||
* @brief Base class for all CANN tasks to be submitted to the task queue.
|
||||
*
|
||||
* Users should override the run_task() method with actual task logic.
|
||||
*/
|
||||
class cann_task {
|
||||
public:
|
||||
virtual void run_task() {}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
|
||||
*/
|
||||
class cann_task_queue {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
|
||||
*
|
||||
* @param capacity Queue capacity. Must be a power of 2.
|
||||
* @param device Target device ID (used for context setting).
|
||||
*/
|
||||
explicit cann_task_queue(size_t capacity, int32_t device) :
|
||||
buffer_(capacity),
|
||||
capacity_(capacity),
|
||||
head_(0),
|
||||
tail_(0),
|
||||
running_(false),
|
||||
device_(device) {
|
||||
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
|
||||
mask_ = capacity_ - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Attempts to enqueue a task into the queue.
|
||||
*
|
||||
* @param item Unique pointer to the task.
|
||||
* @return true if the task was successfully enqueued, false if the queue was full.
|
||||
*/
|
||||
bool enqueue(std::unique_ptr<cann_task> && item) {
|
||||
size_t next_tail = (tail_ + 1) & mask_;
|
||||
|
||||
if (next_tail == head_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
buffer_[tail_] = std::move(item);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
tail_ = next_tail;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Submits a task to the queue, and starts the worker thread if not already running.
|
||||
*
|
||||
* @param task Task to be submitted.
|
||||
*/
|
||||
void submit_task(std::unique_ptr<cann_task> && task) {
|
||||
while (!enqueue(std::move(task))) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!running_) {
|
||||
running_ = true;
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits until the queue is completely empty and no tasks are being processed.
|
||||
*/
|
||||
void wait() {
|
||||
while (running_ && head_ != tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stops the task queue and joins the worker thread.
|
||||
*/
|
||||
void stop() {
|
||||
running_ = false;
|
||||
if (thread_.joinable()) {
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Worker thread function that continuously dequeues and executes tasks.
|
||||
*/
|
||||
void execute() {
|
||||
ggml_cann_set_device(device_);
|
||||
|
||||
while (running_) {
|
||||
if (head_ == tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
buffer_[head_]->run_task();
|
||||
buffer_[head_].reset();
|
||||
head_ = (head_ + 1) & mask_;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<cann_task>> buffer_;
|
||||
const size_t capacity_;
|
||||
size_t mask_;
|
||||
size_t head_;
|
||||
size_t tail_;
|
||||
bool running_;
|
||||
std::thread thread_;
|
||||
int32_t device_;
|
||||
};
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
struct ggml_graph_node_properties {
|
||||
// dst tensor
|
||||
@@ -474,7 +350,6 @@ struct ggml_backend_cann_context {
|
||||
ggml_cann_graph_lru_cache graph_lru_cache;
|
||||
bool acl_graph_mode = true;
|
||||
#endif
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
// Rope Cache
|
||||
ggml_cann_rope_cache rope_cache;
|
||||
@@ -488,15 +363,10 @@ struct ggml_backend_cann_context {
|
||||
* @brief Constructor for initializing the context with a given device.
|
||||
* @param device Device ID.
|
||||
*/
|
||||
explicit ggml_backend_cann_context(int device) :
|
||||
device(device),
|
||||
name("CANN" + std::to_string(device)),
|
||||
task_queue(1024, device) {
|
||||
explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
|
||||
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
|
||||
#ifdef USE_ACL_GRAPH
|
||||
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
|
||||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
@@ -509,7 +379,6 @@ struct ggml_backend_cann_context {
|
||||
*/
|
||||
~ggml_backend_cann_context() {
|
||||
ggml_cann_set_device(device);
|
||||
task_queue.stop();
|
||||
if (copy_event != nullptr) {
|
||||
ACL_CHECK(aclrtDestroyEvent(copy_event));
|
||||
}
|
||||
|
||||
@@ -22,24 +22,24 @@
|
||||
|
||||
#include "ggml-cann.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <stdarg.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
#include <stdarg.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <unordered_set>
|
||||
#include <optional>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml.h"
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
|
||||
@@ -1177,19 +1177,18 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
|
||||
* across calls. This reduces overhead from repeated memory allocation and deallocation.
|
||||
*/
|
||||
static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
|
||||
aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor * executor;
|
||||
|
||||
// TransMatmulWeight
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
|
||||
// Avoid frequent malloc/free of the workspace.
|
||||
g_nz_workspaces[device].realloc(workspaceSize);
|
||||
|
||||
void * g_nz_workspace = g_nz_workspaces[device].get();
|
||||
|
||||
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
|
||||
ACL_CHECK(aclDestroyTensor(weightTransposed));
|
||||
}
|
||||
|
||||
// TODO: need handle tensor which has paddings.
|
||||
@@ -1641,7 +1640,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
|
||||
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||
},
|
||||
/* .device = */
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
@@ -1949,7 +1948,8 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
|
||||
cann_ctx->stream()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1974,7 +1974,8 @@ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
|
||||
cann_ctx->stream()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2035,7 +2036,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
|
||||
|
||||
// wait for task_queue empty to keep task order.
|
||||
cann_ctx_src->task_queue.wait();
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
cann_ctx_src->stream()));
|
||||
// record event on src stream after the copy
|
||||
@@ -2068,7 +2068,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
*/
|
||||
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
||||
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
|
||||
cann_ctx->task_queue.wait();
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
}
|
||||
@@ -2485,6 +2484,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] > 896) {
|
||||
return false;
|
||||
}
|
||||
#ifdef ASCEND_310P
|
||||
if (!ggml_is_contiguous(op->src[0])) {
|
||||
return false;
|
||||
@@ -2521,10 +2523,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
// value of paddingW should be at most half of kernelW
|
||||
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
|
||||
}
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_REPEAT:
|
||||
|
||||
@@ -646,7 +646,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -1041,7 +1041,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||
|
||||
@@ -1989,7 +1989,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr % 16;; // Used to align nr with boundary of 16
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -2727,7 +2727,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif //AVX512F
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
@@ -3467,7 +3467,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
|
||||
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
|
||||
@@ -4947,7 +4947,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
y = 0;
|
||||
}
|
||||
|
||||
#endif //AVX512F
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
|
||||
@@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_sum_rows(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
{
|
||||
ggml_compute_forward_cumsum(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
ggml_compute_forward_mean(params, tensor);
|
||||
@@ -1927,6 +1931,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_leaky_relu(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
ggml_compute_forward_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
ggml_compute_forward_fill(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
ggml_compute_forward_flash_attn_ext(params, tensor);
|
||||
@@ -1982,6 +1994,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
ggml_compute_forward_solve_tri(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_CUSTOM1:
|
||||
{
|
||||
ggml_compute_forward_map_custom1(params, tensor);
|
||||
@@ -2140,6 +2156,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2157,6 +2176,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
@@ -2179,6 +2199,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
|
||||
+208
-2
@@ -9,6 +9,7 @@
|
||||
|
||||
#include <cfloat>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
|
||||
// ggml_compute_forward_dup
|
||||
@@ -1395,6 +1396,56 @@ void ggml_compute_forward_sum(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_cumsum
|
||||
|
||||
static void ggml_compute_forward_cumsum_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ne0 == ne00);
|
||||
GGML_ASSERT(ne1 == ne01);
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
GGML_ASSERT(ne3 == ne03);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_cumsum(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_cumsum_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_sum_rows
|
||||
|
||||
static void ggml_compute_forward_sum_rows_f32(
|
||||
@@ -2141,6 +2192,83 @@ static void ggml_compute_forward_gelu(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_fill
|
||||
|
||||
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const float c = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, dst);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne2*ne1);
|
||||
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
||||
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
ggml_vec_set_f32(ne0, dst_ptr, c);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
ggml_compute_forward_fill_f32(params, dst);
|
||||
}
|
||||
|
||||
// ggml_compute_tri
|
||||
|
||||
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
bool (*bipred)(int, int);
|
||||
|
||||
switch (ttype) {
|
||||
case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
|
||||
default: GGML_ABORT("invalid tri type");
|
||||
}
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
for (int i0 = 0; i0 < ne0; ++i0) {
|
||||
dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_tri_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_gelu_erf
|
||||
|
||||
static void ggml_compute_forward_gelu_erf_f32(
|
||||
@@ -8536,7 +8664,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const float dA = expf(dt_soft_plus * A[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
@@ -8633,7 +8761,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// n_head
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = ggml_softplus(dt[h]);
|
||||
const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
// dim
|
||||
@@ -8916,6 +9044,14 @@ void ggml_compute_forward_unary(
|
||||
{
|
||||
ggml_compute_forward_xielu(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
{
|
||||
ggml_compute_forward_expm1(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
{
|
||||
ggml_compute_forward_softplus(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -9512,6 +9648,76 @@ void ggml_compute_forward_gla(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ne00 == ne01); // A must be square
|
||||
GGML_ASSERT(ne0 == ne10); // solution cols == B cols
|
||||
GGML_ASSERT(ne1 == ne11); // solution rows == B rows
|
||||
|
||||
GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
|
||||
GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t k = ne10; // number of RHS columns
|
||||
const int64_t n = ne11; // A is n×n
|
||||
const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
|
||||
|
||||
// chunks per thread
|
||||
const int64_t dr = (nr + nth - 1)/nth;
|
||||
|
||||
// chunk range for this thread
|
||||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float * A = (const float *) src0->data; // [n, n, B1, B2]
|
||||
const float * B = (const float *) src1->data; // [n, k, B1, B2]
|
||||
float * X = ( float *) dst->data; // [n, k, B1, B2]
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*k);
|
||||
const int64_t i02 = (ir - i03*ne02*k)/k;
|
||||
const int64_t i01 = (ir - i03*ne02*k - i02*k);
|
||||
|
||||
const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
|
||||
const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
|
||||
|
||||
float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
|
||||
|
||||
for (int64_t i00 = 0; i00 < n; ++i00) {
|
||||
float sum = 0.0f;
|
||||
for (int64_t t = 0; t < i00; ++t) {
|
||||
sum += A_batch[i00 * n + t] * X_batch[i01 * n + t];
|
||||
}
|
||||
|
||||
const float diag = A_batch[i00 * n + i00];
|
||||
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
|
||||
|
||||
X_batch[i01 * n + i00] = (B_batch[i00 * k + i01] - sum) / diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_compute_forward_solve_tri_f32(params, dst);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv7
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
|
||||
@@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
@@ -81,6 +82,8 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
|
||||
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_flash_attn_back(
|
||||
const struct ggml_compute_params * params,
|
||||
@@ -96,6 +99,7 @@ void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params,
|
||||
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -1600,29 +1600,52 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
|
||||
void forward_mul_mat_one_chunk(ggml_compute_params * params,
|
||||
ggml_tensor * op,
|
||||
int64_t src0_start,
|
||||
int64_t src0_end,
|
||||
int64_t src1_start,
|
||||
int64_t src1_end) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
ggml_tensor * dst = op;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const void * src1_wdata = params->wdata;
|
||||
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
|
||||
|
||||
GGML_ASSERT(ne03 == 1 && ne13 == 1);
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
|
||||
const int64_t i12 = src1_start / ne1;
|
||||
const int64_t i11 = src1_start - i12 * ne1;
|
||||
|
||||
// Determine batch index
|
||||
const int64_t i02 = i12 / r2;
|
||||
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
|
||||
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
|
||||
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
|
||||
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
|
||||
|
||||
const int64_t nrows = src1_end - src1_start;
|
||||
const int64_t ncols = src0_end - src0_start;
|
||||
|
||||
GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (ne11 > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
if (nrows > 3) {
|
||||
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
|
||||
src0_ptr + src0_start * nb01, src1_ptr,
|
||||
nrows - (nrows % 4), ncols);
|
||||
}
|
||||
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
|
||||
ne01, src0_ptr + src0_start * nb01,
|
||||
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1647,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// TODO: General batched mul mat for 4D tensors
|
||||
// Currently only supports 3D tensors
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne3 == 1);
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
|
||||
@@ -1654,47 +1683,64 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
|
||||
char * wdata = static_cast<char *>(params->wdata);
|
||||
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
|
||||
assert(params->wsize >= nbw1 * ne11);
|
||||
assert(params->wsize >= nbw2 * ne12);
|
||||
|
||||
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
||||
|
||||
int64_t i11_processed = 0;
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
// INFO: Quantization is done in planes to avoid extra complexity in chunking.
|
||||
// Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
|
||||
// the planes are broadcast.
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||
char * data_ptr = (char *) src1->data + i12 * nb12;
|
||||
char * wdata_ptr = wdata + i12 * nbw2;
|
||||
|
||||
i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
|
||||
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
|
||||
}
|
||||
|
||||
const int64_t i11_processed = ne11 - ne11 % 4;
|
||||
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
|
||||
}
|
||||
}
|
||||
|
||||
// disable for NUMA
|
||||
const bool disable_chunking = ggml_is_numa();
|
||||
|
||||
// 4x chunks per thread
|
||||
int64_t nr = ggml_nrows(op->src[0]);
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||
const int64_t nr0 = ggml_nrows(op->src[0]);
|
||||
|
||||
int nth_scaled = nth * 4;
|
||||
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
|
||||
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
|
||||
|
||||
// src1 is chunked only by full planes.
|
||||
// When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
|
||||
// to route them thorugh GEMV.
|
||||
// nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
|
||||
// to avoid affecting their performance
|
||||
int64_t nchunk1 = ne12;
|
||||
|
||||
// Ensure minimum chunk size to avoid alignment issues with high thread counts
|
||||
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
|
||||
const int64_t min_chunk_size = NB_COLS;
|
||||
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
|
||||
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
|
||||
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
}
|
||||
|
||||
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||
nchunk = nth;
|
||||
if (nth == 1 || nchunk0 < nth || disable_chunking) {
|
||||
nchunk0 = nth;
|
||||
}
|
||||
|
||||
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
|
||||
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
|
||||
// This prevents creating too many tiny chunks that could overlap after alignment
|
||||
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
|
||||
if (nchunk > max_nchunk) {
|
||||
nchunk = max_nchunk;
|
||||
}
|
||||
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
|
||||
nchunk0 = MIN(nchunk0, max_nchunk);
|
||||
|
||||
if (ith == 0) {
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
@@ -1706,23 +1752,30 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
int64_t src0_start = (current_chunk * ne01) / nchunk;
|
||||
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
|
||||
int64_t src0_start = dr0 * ith0;
|
||||
int64_t src0_end = MIN(src0_start + dr0, nr0);
|
||||
|
||||
// full-plane range for src1
|
||||
int64_t src1_start = ith1 * ne11;
|
||||
int64_t src1_end = (ith1 + 1) * ne11;
|
||||
|
||||
// Align boundaries to NB_COLS - round up to ensure all data is included
|
||||
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
|
||||
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
if (src0_end > ne01) {
|
||||
src0_end = ne01;
|
||||
}
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
src0_end = MIN(src0_end, ne01);
|
||||
|
||||
// Make sure current plane is the last one before exiting
|
||||
if (src0_start >= src0_end) {
|
||||
break;
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
continue;
|
||||
}
|
||||
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
|
||||
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
|
||||
@@ -73,6 +73,14 @@ static inline float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static inline float op_expm1(float x) {
|
||||
return expf(x) - 1.0f;
|
||||
}
|
||||
|
||||
static inline float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static inline float op_floor(float x) {
|
||||
return floorf(x);
|
||||
}
|
||||
@@ -290,6 +298,14 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
|
||||
unary_op<op_log>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_expm1>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_softplus>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_floor>(params, dst);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
|
||||
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -1416,6 +1416,16 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (i == 0) {
|
||||
y[i] = x[i];
|
||||
} else {
|
||||
y[i] = y[i - 1] + x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
|
||||
ggml_float sum = 0.0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
|
||||
@@ -2527,6 +2527,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
ggml_cuda_op_trunc(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
ggml_cuda_op_expm1(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_cuda_op_softplus(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3829,6 +3835,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
|
||||
@@ -81,6 +81,14 @@ static __device__ __forceinline__ float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_expm1(float x) {
|
||||
return expm1f(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_elu(float x) {
|
||||
return (x > 0.f) ? x : expm1f(x);
|
||||
}
|
||||
@@ -233,6 +241,14 @@ void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_trunc>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_expm1>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_softplus>(ctx, dst);
|
||||
}
|
||||
/* gated ops */
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
|
||||
@@ -61,6 +61,10 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline float ggml_softplus(float input) {
|
||||
static inline float ggml_compute_softplus_f32(float input) {
|
||||
return (input > 20.0f) ? input : logf(1 + expf(input));
|
||||
}
|
||||
//
|
||||
|
||||
@@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
@@ -943,6 +981,34 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_ARGSORT);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
ggml_sort_order order = (ggml_sort_order) op->op_params[0];
|
||||
|
||||
const char * order_str = "undefined";
|
||||
switch (order) {
|
||||
case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
||||
case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
|
||||
@@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
@@ -125,6 +127,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
@@ -904,8 +905,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARGSORT:
|
||||
// TODO: Support arbitrary column width
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
case GGML_OP_ARANGE:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
@@ -990,7 +989,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_I32:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32;
|
||||
default:
|
||||
return false;
|
||||
};
|
||||
|
||||
@@ -612,6 +612,45 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_sum_rows;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t net0;
|
||||
int64_t net1;
|
||||
int64_t net2;
|
||||
int64_t net3;
|
||||
uint64_t nbt0;
|
||||
uint64_t nbt1;
|
||||
uint64_t nbt2;
|
||||
uint64_t nbt3;
|
||||
bool outb;
|
||||
} ggml_metal_kargs_cumsum_blk;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t net0;
|
||||
int64_t net1;
|
||||
int64_t net2;
|
||||
int64_t net3;
|
||||
uint64_t nbt0;
|
||||
uint64_t nbt1;
|
||||
uint64_t nbt2;
|
||||
uint64_t nbt3;
|
||||
} ggml_metal_kargs_cumsum_add;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
@@ -793,10 +832,28 @@ typedef struct {
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int64_t ncols;
|
||||
int64_t ncols_pad;
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
} ggml_metal_kargs_argsort;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t len;
|
||||
} ggml_metal_kargs_argsort_merge;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne0;
|
||||
float start;
|
||||
|
||||
@@ -311,6 +311,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_cumsum(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
n_fuse = ggml_metal_op_soft_max(ctx, idx);
|
||||
@@ -539,7 +543,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
||||
|
||||
@@ -585,7 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
@@ -694,7 +698,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float scale;
|
||||
float bias;
|
||||
@@ -733,7 +737,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
@@ -772,7 +776,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
@@ -802,7 +806,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
if (op->src[1]) {
|
||||
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
||||
@@ -834,18 +838,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
|
||||
|
||||
//[encoder setComputePipelineState:pipeline];
|
||||
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
//if (src1) {
|
||||
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
//} else {
|
||||
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
//}
|
||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
@@ -907,7 +899,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -941,14 +933,6 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
//[encoder setComputePipelineState:pipeline];
|
||||
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
@@ -961,6 +945,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
||||
|
||||
int nth = 1;
|
||||
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne00 <= nth*nth);
|
||||
|
||||
const int64_t net0 = (ne00 + nth - 1) / nth;
|
||||
const int64_t net1 = ne01;
|
||||
const int64_t net2 = ne02;
|
||||
const int64_t net3 = ne03;
|
||||
|
||||
const uint64_t nbt0 = sizeof(float);
|
||||
const uint64_t nbt1 = net0*nbt0;
|
||||
const uint64_t nbt2 = net1*nbt1;
|
||||
const uint64_t nbt3 = net2*nbt2;
|
||||
|
||||
const size_t smem = GGML_PAD(32*sizeof(float), 16);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
|
||||
{
|
||||
ggml_metal_kargs_cumsum_blk args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.net0 =*/ net0,
|
||||
/*.net1 =*/ net1,
|
||||
/*.net2 =*/ net2,
|
||||
/*.net3 =*/ net3,
|
||||
/*.nbt0 =*/ nbt0,
|
||||
/*.nbt1 =*/ nbt1,
|
||||
/*.nbt2 =*/ nbt2,
|
||||
/*.nbt3 =*/ nbt3,
|
||||
/*.outb =*/ ne00 > nth,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
||||
}
|
||||
|
||||
if (ne00 > nth) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_kargs_cumsum_blk args = {
|
||||
/*.ne00 =*/ net0,
|
||||
/*.ne01 =*/ net1,
|
||||
/*.ne02 =*/ net2,
|
||||
/*.ne03 =*/ net3,
|
||||
/*.nb00 =*/ nbt0,
|
||||
/*.nb01 =*/ nbt1,
|
||||
/*.nb02 =*/ nbt2,
|
||||
/*.nb03 =*/ nbt3,
|
||||
/*.net0 =*/ net0,
|
||||
/*.net1 =*/ net1,
|
||||
/*.net2 =*/ net2,
|
||||
/*.net3 =*/ net3,
|
||||
/*.nbt0 =*/ nbt0,
|
||||
/*.nbt1 =*/ nbt1,
|
||||
/*.nbt2 =*/ nbt2,
|
||||
/*.nbt3 =*/ nbt3,
|
||||
/*.outb =*/ false,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
|
||||
}
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
||||
|
||||
ggml_metal_kargs_cumsum_add args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.net0 =*/ net0,
|
||||
/*.net1 =*/ net1,
|
||||
/*.net2 =*/ net2,
|
||||
/*.net3 =*/ net3,
|
||||
/*.nbt0 =*/ nbt0,
|
||||
/*.nbt1 =*/ nbt1,
|
||||
/*.nbt2 =*/ nbt2,
|
||||
/*.nbt3 =*/ nbt3,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_add);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -972,7 +1099,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
|
||||
@@ -1017,7 +1144,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
||||
|
||||
@@ -1081,7 +1208,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
@@ -1169,7 +1296,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_ssm_conv args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1224,7 +1351,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const ggml_tensor * src3 = op->src[3];
|
||||
const ggml_tensor * src4 = op->src[4];
|
||||
@@ -1310,7 +1437,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
||||
const int64_t T = op->src[0]->ne[2];
|
||||
@@ -1351,7 +1478,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
@@ -1424,7 +1551,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t * opts = op->op_params;
|
||||
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
||||
@@ -1488,7 +1615,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
||||
@@ -1729,7 +1856,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
// src2 = ids
|
||||
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
|
||||
@@ -1975,7 +2102,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
||||
// note: always reserve the padding space to avoid graph reallocations
|
||||
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
||||
const bool has_kvpad = true;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
|
||||
@@ -1984,7 +2113,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
} else {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
||||
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
||||
const bool has_kvpad = true;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_NCPSG*(
|
||||
@@ -2020,9 +2150,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
||||
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
|
||||
|
||||
// this optimization is not useful for the vector kernels
|
||||
if (is_vec) {
|
||||
return res;
|
||||
}
|
||||
// note: always reserve the blk buffer to avoid graph reallocations
|
||||
//if (is_vec) {
|
||||
// return res;
|
||||
//}
|
||||
|
||||
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
|
||||
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
||||
@@ -2049,13 +2180,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
||||
|
||||
size_t res = 0;
|
||||
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
// note: always reserve the temp buffer to avoid graph reallocations
|
||||
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
if (true) {
|
||||
const int64_t nwg = 32;
|
||||
const int64_t ne01_max = std::min(ne01, 32);
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -2184,8 +2318,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (has_mask) {
|
||||
@@ -2215,8 +2347,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
@@ -2356,8 +2486,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
@@ -2688,7 +2816,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, op->op_params, sizeof(float));
|
||||
@@ -2736,7 +2864,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
|
||||
|
||||
@@ -2791,7 +2919,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, op->op_params, sizeof(float));
|
||||
@@ -2927,7 +3055,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
// make sure we have one or more position id(ne10) per token(ne02)
|
||||
GGML_ASSERT(ne10 % ne02 == 0);
|
||||
@@ -3021,7 +3149,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
|
||||
@@ -3171,7 +3299,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||
|
||||
@@ -3216,7 +3344,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||
|
||||
@@ -3270,7 +3398,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float sf0 = (float)ne0/op->src[0]->ne[0];
|
||||
const float sf1 = (float)ne1/op->src[0]->ne[1];
|
||||
@@ -3323,7 +3451,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_pad args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -3367,7 +3495,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_pad_reflect_1d args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -3411,7 +3539,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float start;
|
||||
float step;
|
||||
@@ -3429,12 +3557,6 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
||||
|
||||
//[encoder setComputePipelineState:pipeline];
|
||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
||||
//[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
||||
|
||||
//[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
||||
@@ -3453,7 +3575,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int dim = op->op_params[0];
|
||||
const int max_period = op->op_params[1];
|
||||
@@ -3487,7 +3609,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_argmax args = {
|
||||
/*.ne00 = */ ne00,
|
||||
@@ -3523,38 +3645,93 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int64_t ne00_padded = 1;
|
||||
while (ne00_padded < ne00) {
|
||||
ne00_padded *= 2;
|
||||
}
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
// bitonic sort requires the number of elements to be power of 2
|
||||
int nth = 1;
|
||||
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
const int npr = (ne00 + nth - 1)/nth;
|
||||
|
||||
// Metal kernels require the buffer size to be multiple of 16 bytes
|
||||
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
||||
const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
|
||||
const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
|
||||
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
||||
std::swap(bid_dst, bid_tmp);
|
||||
}
|
||||
|
||||
ggml_metal_kargs_argsort args = {
|
||||
/*.ncols =*/ ne00,
|
||||
/*.ncols_pad =*/ ne00_padded
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||
|
||||
int len = nth;
|
||||
|
||||
while (len < ne00) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
ggml_metal_kargs_argsort_merge args_merge = {
|
||||
.ne00 = ne00,
|
||||
.ne01 = ne01,
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.nb00 = nb00,
|
||||
.nb01 = nb01,
|
||||
.nb02 = nb02,
|
||||
.nb03 = nb03,
|
||||
.len = len,
|
||||
};
|
||||
|
||||
// merges per row
|
||||
const int nm = (ne00 + 2*len - 1) / (2*len);
|
||||
|
||||
const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
|
||||
ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
std::swap(bid_dst, bid_tmp);
|
||||
|
||||
len <<= 1;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -3568,7 +3745,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float slope;
|
||||
memcpy(&slope, op->op_params, sizeof(float));
|
||||
@@ -3604,7 +3781,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
@@ -3640,7 +3817,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -197,6 +197,11 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
||||
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
res *= 2;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1832,6 +1832,117 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_blk(
|
||||
constant ggml_metal_kargs_cumsum_blk & args,
|
||||
device const char * src0,
|
||||
device char * tmp,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * src0_row = (device const float *) (src0 +
|
||||
args.nb01*i01 +
|
||||
args.nb02*i02 +
|
||||
args.nb03*i03);
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
|
||||
float v = 0.0f;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
v = src0_row[i00 + tpitg.x];
|
||||
}
|
||||
|
||||
float s = simd_prefix_inclusive_sum(v);
|
||||
|
||||
if (tiisg == N_SIMDWIDTH - 1) {
|
||||
shmem_f32[sgitg] = s;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
s += shmem_f32[sgitg];
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] = s;
|
||||
}
|
||||
|
||||
if (args.outb && tpitg.x == ntg.x - 1) {
|
||||
device float * tmp_row = (device float *) tmp +
|
||||
args.net0*i01 +
|
||||
args.net0*args.net1*i02 +
|
||||
args.net0*args.net1*args.net2*i03;
|
||||
|
||||
tmp_row[ib] = s;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_add(
|
||||
constant ggml_metal_kargs_cumsum_add & args,
|
||||
device const char * tmp,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
if (ib == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * tmp_row = (device const float *) (tmp +
|
||||
args.nbt1*i01 +
|
||||
args.nbt2*i02 +
|
||||
args.nbt3*i03);
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
@@ -4541,69 +4652,227 @@ kernel void kernel_timestep_embedding_f32(
|
||||
// bitonic sort implementation following the CUDA kernels as reference
|
||||
typedef void (argsort_t)(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const float * x,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]);
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_f32_i32(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const float * x,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// bitonic sort
|
||||
int col = tpitg[0];
|
||||
int row = tgpig[1];
|
||||
const int col = tpitg[0];
|
||||
|
||||
if (col >= args.ncols_pad) return;
|
||||
const int i00 = (tgpig[0]/args.ne01)*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * x_row = x + row * args.ncols;
|
||||
threadgroup int32_t * dst_row = shared_values;
|
||||
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||
|
||||
// initialize indices
|
||||
dst_row[col] = col;
|
||||
shmem_i32[col] = i00 + col;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int k = 2; k <= args.ncols_pad; k *= 2) {
|
||||
for (int k = 2; k <= ntg.x; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= args.ncols ||
|
||||
(dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
if (shmem_i32[col] >= args.ne00 ||
|
||||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(dst_row[col], dst_row[ixj]);
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= args.ncols ||
|
||||
(dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
if (shmem_i32[ixj] >= args.ne00 ||
|
||||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(dst_row[col], dst_row[ixj]);
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < args.ncols) {
|
||||
dst[row * args.ncols + col] = dst_row[col];
|
||||
if (i00 + col < args.ne00) {
|
||||
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
dst[col] = shmem_i32[col];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
typedef void (argsort_merge_t)(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_merge_f32_i32(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int im = tgpig[0] / args.ne01;
|
||||
const int i01 = tgpig[0] % args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2 * args.len);
|
||||
|
||||
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
|
||||
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
|
||||
|
||||
const int total = len0 + len1;
|
||||
|
||||
device const int32_t * tmp0 = tmp + start
|
||||
+ i01*args.ne00
|
||||
+ i02*args.ne00*args.ne01
|
||||
+ i03*args.ne00*args.ne01*args.ne02;
|
||||
|
||||
device const int32_t * tmp1 = tmp0 + args.len;
|
||||
|
||||
dst += start
|
||||
+ i01*args.ne00
|
||||
+ i02*args.ne00*args.ne01
|
||||
+ i03*args.ne00*args.ne01*args.ne02;
|
||||
|
||||
device const float * src0_row = (device const float *)(src0
|
||||
+ args.nb01*i01
|
||||
+ args.nb02*i02
|
||||
+ args.nb03*i03);
|
||||
|
||||
if (total == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||
|
||||
const int k0 = tpitg.x * chunk;
|
||||
const int k1 = min(k0 + chunk, total);
|
||||
|
||||
if (k0 >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
int low = k0 > len1 ? k0 - len1 : 0;
|
||||
int high = MIN(k0, len0);
|
||||
|
||||
// binary-search partition (i, j) such that i + j = k
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
|
||||
bool take_left;
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
if (take_left) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
|
||||
int i = low;
|
||||
int j = k0 - i;
|
||||
|
||||
// keep the merge fronts into registers
|
||||
int32_t idx0 = 0;
|
||||
float val0 = 0.0f;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
|
||||
int32_t idx1 = 0;
|
||||
float val1 = 0.0f;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
|
||||
for (int k = k0; k < k1; ++k) {
|
||||
int32_t out_idx;
|
||||
|
||||
if (i >= len0) {
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp1[j++];
|
||||
}
|
||||
break;
|
||||
} else if (j >= len1) {
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp0[i++];
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
bool take_left;
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
if (take_left) {
|
||||
out_idx = idx0;
|
||||
++i;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
} else {
|
||||
out_idx = idx1;
|
||||
++j;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst[k] = out_idx;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
kernel void kernel_leaky_relu_f32(
|
||||
constant ggml_metal_kargs_leaky_relu & args,
|
||||
device const float * src0,
|
||||
@@ -6291,6 +6560,7 @@ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
|
||||
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
|
||||
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
|
||||
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
|
||||
#endif
|
||||
|
||||
@@ -119,6 +119,7 @@ set(GGML_OPENCL_KERNELS
|
||||
pad
|
||||
repeat
|
||||
mul_mat_f16_f32
|
||||
mul_mm_f16_f32_kq_kqv
|
||||
conv2d
|
||||
conv2d_f16_f32
|
||||
flash_attn_f32_f16
|
||||
|
||||
@@ -407,6 +407,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_program program_mul_mv_f32_f32;
|
||||
cl_program program_mul;
|
||||
cl_program program_mul_mat_f16_f32_tiled;
|
||||
cl_program program_mul_mm_f16_f32_kqv;
|
||||
cl_program program_mul_mm_f16_f32_kq;
|
||||
cl_program program_div;
|
||||
cl_program program_sub;
|
||||
cl_program program_norm;
|
||||
@@ -481,6 +483,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mat_f16_f32;
|
||||
cl_kernel kernel_mul_mat_f16_f32_l4;
|
||||
cl_kernel kernel_mul_mat_f16_f32_tiled;
|
||||
cl_kernel kernel_mul_mm_f16_f32_kqv;
|
||||
cl_kernel kernel_mul_mm_f16_f32_kq;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
|
||||
@@ -1235,6 +1239,25 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_f16_f32_kq_kqv
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_f16_f32_kq_kqv.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl");
|
||||
#endif
|
||||
backend_ctx->program_mul_mm_f16_f32_kqv =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV ");
|
||||
backend_ctx->program_mul_mm_f16_f32_kq =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -5682,7 +5705,7 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
|
||||
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL));
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
@@ -6665,6 +6688,146 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne02 = src0->ne[2];
|
||||
|
||||
const cl_ulong nb01 = src0->nb[1];
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
|
||||
const int ne10 = src1->ne[0];
|
||||
const int ne11 = src1->ne[1];
|
||||
const int ne12 = src1->ne[2];
|
||||
|
||||
const cl_ulong nb10 = src1->nb[0];
|
||||
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
||||
cl_kernel kernel;
|
||||
cl_context context = backend_ctx->context;
|
||||
|
||||
cl_int status;
|
||||
cl_image_format img_fmt_1d;
|
||||
cl_image_desc img_desc_1d;
|
||||
cl_buffer_region region;
|
||||
cl_mem A_image1d;
|
||||
cl_mem A_sub_buffer;
|
||||
cl_mem B_sub_buffer;
|
||||
cl_mem D_image1d;
|
||||
cl_mem D_sub_buffer;
|
||||
|
||||
int M = ne01;
|
||||
int N = ne1;
|
||||
int K = ne00;
|
||||
|
||||
if (nb01 > nb02) {
|
||||
// KQ
|
||||
kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;
|
||||
} else {
|
||||
// KQV
|
||||
kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv;
|
||||
}
|
||||
// create sub-buffer for A
|
||||
// <--------------------------------------------> //
|
||||
extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra;
|
||||
|
||||
region.origin = (extra0->offset);
|
||||
if (nb01 > nb02) {
|
||||
// KQ
|
||||
region.size = nb01 * ne01;
|
||||
} else {
|
||||
// KQV
|
||||
region.size = nb02 * ne02;
|
||||
}
|
||||
|
||||
A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// <--------------------------------------------> //
|
||||
|
||||
// create sub-buffer for B
|
||||
// <--------------------------------------------> //
|
||||
region.origin = (extra1->offset);
|
||||
region.size = nb10 * ne10 * ne11 * ne12;
|
||||
B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
// <--------------------------------------------> //
|
||||
|
||||
img_fmt_1d = {CL_RGBA, CL_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
if (nb01 > nb02) {
|
||||
img_desc_1d.image_width = (nb01 * ne01 / 4)/4;
|
||||
}
|
||||
else {
|
||||
img_desc_1d.image_width = (nb02 * ne02 / 4)/4;
|
||||
}
|
||||
img_desc_1d.buffer = A_sub_buffer;
|
||||
A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create sub-buffer for output C
|
||||
// <--------------------------------------------> //
|
||||
region.origin = (extrad->offset);
|
||||
region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes
|
||||
D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
// <--------------------------------------------> //
|
||||
|
||||
// create image for C output
|
||||
// <--------------------------------------------> //
|
||||
img_fmt_1d = {CL_R, CL_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4;
|
||||
img_desc_1d.buffer = D_sub_buffer;
|
||||
D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
// <--------------------------------------------> //
|
||||
|
||||
int offset_src0 = 0;
|
||||
int offset_src1 = 0;
|
||||
|
||||
// set kernel args
|
||||
// <--------------------------------------------> //
|
||||
cl_uint k_arg = 0;
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01));
|
||||
|
||||
size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)};
|
||||
size_t local_work_size[3] = {64, 1, 2};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
// deallocate sub buffers and images
|
||||
// <--------------------------------------------> //
|
||||
CL_CHECK(clReleaseMemObject(A_image1d));
|
||||
CL_CHECK(clReleaseMemObject(D_image1d));
|
||||
CL_CHECK(clReleaseMemObject(A_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(B_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(D_sub_buffer));
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -6731,6 +6894,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
cl_context context = backend_ctx->context;
|
||||
|
||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) {
|
||||
|
||||
// init CL objects
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
#define LM_FIRST_256B 0
|
||||
#define LM_SECOND_256B 64
|
||||
#define LM_THIRD_256B 128
|
||||
#define LM_FOURTH_256B 192
|
||||
|
||||
|
||||
inline float16 mm_load_a(
|
||||
image1d_buffer_t matrix_A,
|
||||
uint subMatrixAStartInElements,
|
||||
int nb01,
|
||||
int line_stride_matrix_A_in_bytes
|
||||
) {
|
||||
__private float8 regA;
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
|
||||
#ifdef KQV
|
||||
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
|
||||
#else // KQ
|
||||
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
|
||||
#endif
|
||||
|
||||
regA.s0123 = read_imagef(matrix_A, a_texCoord/4);
|
||||
regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4);
|
||||
|
||||
return convert_float16(as_half16(regA));
|
||||
}
|
||||
|
||||
inline float4 alu_32(
|
||||
float16 regA,
|
||||
__local float4* matrix_B_vec
|
||||
) {
|
||||
|
||||
__private float4 rC = 0;
|
||||
int i = get_sub_group_id() * 64;
|
||||
|
||||
rC += regA.s0 * matrix_B_vec[i];
|
||||
rC += regA.s1 * matrix_B_vec[i + 16];
|
||||
rC += regA.s4 * matrix_B_vec[i + 1];
|
||||
rC += regA.s5 * matrix_B_vec[i + 17];
|
||||
rC += regA.s8 * matrix_B_vec[i + 2];
|
||||
rC += regA.s9 * matrix_B_vec[i + 18];
|
||||
rC += regA.sc * matrix_B_vec[i + 3];
|
||||
rC += regA.sd * matrix_B_vec[i + 19];
|
||||
|
||||
i += 32;
|
||||
|
||||
rC += regA.s2 * matrix_B_vec[i];
|
||||
rC += regA.s3 * matrix_B_vec[i + 16];
|
||||
rC += regA.s6 * matrix_B_vec[i + 1];
|
||||
rC += regA.s7 * matrix_B_vec[i + 17];
|
||||
rC += regA.sa * matrix_B_vec[i + 2];
|
||||
rC += regA.sb * matrix_B_vec[i + 18];
|
||||
rC += regA.se * matrix_B_vec[i + 3];
|
||||
rC += regA.sf * matrix_B_vec[i + 19];
|
||||
|
||||
return rC;
|
||||
}
|
||||
|
||||
inline float16 alu_16(
|
||||
float16 regA,
|
||||
__local float* matrix_B_local
|
||||
) {
|
||||
float16 out;
|
||||
__local float4* matrix_B_vec = (__local float4*)matrix_B_local;
|
||||
|
||||
out.s0123 = alu_32(regA, matrix_B_vec);
|
||||
out.s4567 = alu_32(regA, matrix_B_vec + 4);
|
||||
out.s89ab = alu_32(regA, matrix_B_vec + 8);
|
||||
out.scdef = alu_32(regA, matrix_B_vec + 12);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
inline void mm_mad(
|
||||
__local float* matrix_B_local,
|
||||
float16 regA,
|
||||
float8 regB,
|
||||
uint b_localOffsetInWords,
|
||||
float16* regC0_ptr,
|
||||
float16* regC1_ptr
|
||||
) {
|
||||
int offset = b_localOffsetInWords + get_sub_group_id() * 256;
|
||||
|
||||
matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
|
||||
matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
|
||||
matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
|
||||
matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
|
||||
|
||||
float16 add0 = alu_16(regA, matrix_B_local);
|
||||
*regC0_ptr += add0;
|
||||
|
||||
matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
|
||||
matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
|
||||
matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
|
||||
matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
|
||||
|
||||
float16 add1 = alu_16(regA, matrix_B_local);
|
||||
*regC1_ptr += add1;
|
||||
}
|
||||
|
||||
inline void mm_store_c_N(
|
||||
__write_only image1d_buffer_t matrix_C,
|
||||
float16 regC0,
|
||||
float16 regC1,
|
||||
uint subMatrixCStartInElements,
|
||||
int line_stride_matrix_C_in_bytes,
|
||||
int mask
|
||||
) {
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
|
||||
uint strideInWords = line_stride_matrix_C_in_bytes/4;
|
||||
uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m);
|
||||
|
||||
uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords;
|
||||
uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords;
|
||||
uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords;
|
||||
uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords;
|
||||
uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords;
|
||||
uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords;
|
||||
uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords;
|
||||
uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords;
|
||||
uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords;
|
||||
uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
|
||||
uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
|
||||
uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
|
||||
uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
|
||||
uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
|
||||
uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
|
||||
uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
|
||||
uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
|
||||
uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
|
||||
uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
|
||||
uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
|
||||
uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
|
||||
uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
|
||||
uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
|
||||
uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
|
||||
uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
|
||||
uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
|
||||
uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
|
||||
uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
|
||||
uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
|
||||
uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
|
||||
uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
|
||||
|
||||
if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); }
|
||||
if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); }
|
||||
if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); }
|
||||
if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); }
|
||||
if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); }
|
||||
if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); }
|
||||
if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); }
|
||||
if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); }
|
||||
if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); }
|
||||
if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); }
|
||||
if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
|
||||
if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
|
||||
if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
|
||||
if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
|
||||
if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
|
||||
if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
|
||||
if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
|
||||
if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
|
||||
if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
|
||||
if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
|
||||
if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
|
||||
if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
|
||||
if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
|
||||
if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
|
||||
if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
|
||||
if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
|
||||
if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
|
||||
if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
|
||||
if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
|
||||
if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
|
||||
if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
|
||||
if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
|
||||
}
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
#ifdef KQV
|
||||
__kernel void mul_mm_f16_f32_kqv(
|
||||
#else
|
||||
__kernel void mul_mm_f16_f32_kq(
|
||||
#endif
|
||||
__read_only image1d_buffer_t matrix_A,
|
||||
int offset0,
|
||||
__global float* matrix_B,
|
||||
int offset1,
|
||||
__write_only image1d_buffer_t matrix_C,
|
||||
int offsetd,
|
||||
int M, int K, int N,
|
||||
int D_A,
|
||||
int D_B,
|
||||
int nb01
|
||||
) {
|
||||
|
||||
uint block_id_m = get_global_id(1);
|
||||
uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
|
||||
uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
|
||||
|
||||
__private float16 regA;
|
||||
__private float8 regB;
|
||||
__private float16 regC0;
|
||||
__private float16 regC1;
|
||||
|
||||
const uint col = block_id_m * TILESIZE_M;
|
||||
const uint row = block_id_n * TILESIZE_N;
|
||||
const uint depth_A = block_id_d / (D_B/D_A);
|
||||
const uint depth_B = block_id_d;
|
||||
|
||||
#ifdef KQV
|
||||
int line_stride_matrix_A_in_bytes = nb01 * M;
|
||||
int line_stride_matrix_B_in_bytes = K * N * 4;
|
||||
#else
|
||||
int line_stride_matrix_A_in_bytes = K * D_A * 2;
|
||||
int line_stride_matrix_B_in_bytes = K * D_B * 4;
|
||||
#endif
|
||||
|
||||
int line_stride_matrix_C_in_bytes = M * 4;
|
||||
|
||||
const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
|
||||
const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
|
||||
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
|
||||
uint b_localOffsetInWords = (sub_block_id_m/16)*16
|
||||
+ ((((sub_block_id_m)>>0)&1)<<2)
|
||||
+ ((((sub_block_id_m)>>1)&1)<<3)
|
||||
+ ((((sub_block_id_m)>>2)&1)<<0)
|
||||
+ ((((sub_block_id_m)>>3)&1)<<1);
|
||||
|
||||
uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
|
||||
uint b_globalOffsetInWords00, b_globalOffsetInWords16;
|
||||
#ifdef KQV
|
||||
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
|
||||
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
|
||||
uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
|
||||
uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
|
||||
#else
|
||||
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
|
||||
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
|
||||
uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
|
||||
uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
|
||||
#endif
|
||||
|
||||
__local float matrix_B_local[1024];
|
||||
|
||||
for (uint step=0; step < K; step+=TILESIZE_K) {
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
|
||||
|
||||
uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
|
||||
uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
|
||||
|
||||
regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
|
||||
regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
|
||||
|
||||
mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, ®C0, ®C1);
|
||||
|
||||
subMatrixAStartInElements += TILESIZE_K;
|
||||
subMatrixBStartInElements += TILESIZE_K;
|
||||
}
|
||||
|
||||
uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
|
||||
mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
|
||||
}
|
||||
|
||||
@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
// The size of sum is sizeof(float)*subgroup_size.
|
||||
// Each subgroup writes its partial sum to this array.
|
||||
// So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
|
||||
// This is generally true -
|
||||
// for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).
|
||||
if (get_sub_group_id() == 0) {
|
||||
sum[get_sub_group_local_id()] = 0.0f;
|
||||
}
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
sum[get_sub_group_id()] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
|
||||
if (get_local_id(0) < i) {
|
||||
sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
||||
}
|
||||
}
|
||||
if (get_local_id(0) == 0) {
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
//for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
|
||||
// if (get_local_id(0) < i) {
|
||||
// sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
||||
// }
|
||||
//}
|
||||
//if (get_local_id(0) == 0) {
|
||||
// sum[0] /= ne00;
|
||||
//}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
//barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
float mean = sum[0];
|
||||
sumf = sum[get_sub_group_local_id()];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
float mean = sumf / ne00;
|
||||
float scale = 1.0f/sqrt(mean + eps);
|
||||
|
||||
global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
+111
-249
@@ -170,73 +170,31 @@ static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_sgn(x[i]);
|
||||
}
|
||||
}
|
||||
template<typename T, typename F>
|
||||
static void unary_op_generic_kernel(
|
||||
const T * x,
|
||||
T * dst,
|
||||
const int k,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
|
||||
const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,
|
||||
const sycl::nd_item<1> & item_ct1,
|
||||
F func) {
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
(void) ne3;
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_abs(x[i]);
|
||||
}
|
||||
}
|
||||
const int64_t i0 = i % ne0;
|
||||
const int64_t i1 = (i / ne0) % ne1;
|
||||
const int64_t i2 = (i / (ne0*ne1)) % ne2;
|
||||
const int64_t i3 = i / (ne0*ne1*ne2);
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_elu(x[i]);
|
||||
}
|
||||
}
|
||||
const char * src_base = (const char *) x;
|
||||
char * dst_base = (char *) dst;
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_gelu(x[i]);
|
||||
}
|
||||
}
|
||||
const T * srcp = (const T *)(src_base + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 );
|
||||
T * dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_silu(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_gelu_quick(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_gelu_erf(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_tanh(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_relu(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_sigmoid(x[i]);
|
||||
*dstp = func(*srcp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,27 +219,6 @@ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::n
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_hardsigmoid(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_hardswish(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_exp(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
@@ -289,19 +226,6 @@ static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::n
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_neg(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_step(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -620,6 +544,48 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
}
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
static inline void ggml_sycl_op_unary(
|
||||
ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
|
||||
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
|
||||
const size_t nb0 = src0->nb[0];
|
||||
const size_t nb1 = src0->nb[1];
|
||||
const size_t nb2 = src0->nb[2];
|
||||
const size_t nb3 = src0->nb[3];
|
||||
|
||||
const size_t nbd0 = dst->nb[0];
|
||||
const size_t nbd1 = dst->nb[1];
|
||||
const size_t nbd2 = dst->nb[2];
|
||||
const size_t nbd3 = dst->nb[3];
|
||||
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_generic_kernel(
|
||||
src, dst_ptr, k_elements,
|
||||
ne0, ne1, ne2, ne3,
|
||||
nb0, nb1, nb2, nb3,
|
||||
nbd0, nbd1, nbd2, nbd3,
|
||||
item_ct1,
|
||||
func
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
@@ -645,159 +611,75 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_sgn(x);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_abs(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_elu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_silu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_gelu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_gelu_quick(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_gelu_erf(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_tanh(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_relu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_hardsigmoid(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_hardswish(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_exp(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
@@ -814,42 +696,22 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_neg(x);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_step(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_sigmoid(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -4360,21 +4360,22 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
}
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_STEP:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
return true;
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,21 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_d[i] = D_TYPE(abs(float(data_a[i])));
|
||||
}
|
||||
@@ -7,6 +7,7 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
@@ -108,6 +109,38 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
@@ -153,21 +186,6 @@ void main() {
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float mvf = masksh[c * cols_per_iter + col_tid][r];
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
|
||||
@@ -148,6 +149,37 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
float mask_cache[Bc * Br / WorkGroupSize];
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
mask_cache[idx / WorkGroupSize] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
@@ -208,7 +240,8 @@ void main() {
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
float f = mask_cache[idx / WorkGroupSize];
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return max(x, y);
|
||||
}
|
||||
|
||||
float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
|
||||
return max(x, y);
|
||||
}
|
||||
|
||||
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return x;
|
||||
}
|
||||
@@ -142,6 +146,44 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (nem1_bounds_check) {
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
@@ -158,31 +200,7 @@ void main() {
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (nem1_bounds_check) {
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
}
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
}
|
||||
|
||||
// Clear padding elements to -inf, so they don't contribute to rowmax
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
|
||||
}
|
||||
@@ -11,29 +11,7 @@
|
||||
#define EXPERT_COUNT 8
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
#ifndef MMQ
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#else
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
#ifdef B_TYPE_VEC2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||
#endif
|
||||
#ifdef B_TYPE_VEC4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 4) readonly buffer IDS {int data_ids[];};
|
||||
#endif
|
||||
#include "mul_mat_vec_iface.glsl"
|
||||
|
||||
#include "dequant_funcs.glsl"
|
||||
|
||||
@@ -48,8 +26,7 @@ layout (push_constant) uniform parameter
|
||||
uint batch_stride_b;
|
||||
uint batch_stride_d;
|
||||
|
||||
uint enable_bias;
|
||||
uint enable_scale;
|
||||
uint fusion_flags;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint nei0;
|
||||
@@ -123,17 +100,24 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
@@ -171,17 +155,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
temp[j][n] += tmpsh[j][n][s];
|
||||
}
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
@@ -209,17 +200,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
#include "types.glsl"
|
||||
|
||||
#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
|
||||
#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
|
||||
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
|
||||
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
|
||||
|
||||
#ifndef MMQ
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_VEC4)
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
#endif
|
||||
#else
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
#ifdef B_TYPE_VEC2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||
#endif
|
||||
#ifdef B_TYPE_VEC4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Fuse0 {D_TYPE data_fuse0[];};
|
||||
layout (binding = 4) readonly buffer Fuse1 {D_TYPE data_fuse1[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 5) readonly buffer IDS {int data_ids[];};
|
||||
#endif
|
||||
|
||||
@@ -8,14 +8,7 @@
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
#include "mul_mat_vec_iface.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
|
||||
uint nb03;
|
||||
uint nb13;
|
||||
uint nb23;
|
||||
uint enable_bias;
|
||||
uint fusion_flags;
|
||||
} p;
|
||||
|
||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||
@@ -120,9 +113,12 @@ void main() {
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
if (p.enable_bias != 0) {
|
||||
tmp[0] += FLOAT_TYPE(data_bias[idst]);
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
tmp[0] += FLOAT_TYPE(data_fuse0[idst]);
|
||||
}
|
||||
dst[idst] = tmp[0];
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
tmp[0] += FLOAT_TYPE(data_fuse1[idst]);
|
||||
}
|
||||
data_d[idst] = tmp[0];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,14 +10,7 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
#include "mul_mat_vec_iface.glsl"
|
||||
|
||||
layout(constant_id = 0) const int BLOCK_SIZE = 32;
|
||||
// gqa_ratio is in the range [1,8]
|
||||
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
|
||||
uint nchannels_y;
|
||||
uint b_offset;
|
||||
uint d_offset;
|
||||
uint enable_bias;
|
||||
uint fusion_flags;
|
||||
} p;
|
||||
|
||||
#if !USE_SUBGROUP_ADD
|
||||
@@ -151,10 +144,13 @@ void main() {
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
// dst is not transposed and not permuted
|
||||
const uint idst = (channel + c)*nrows_dst + row_dst;
|
||||
if (p.enable_bias != 0) {
|
||||
temp[c] += FLOAT_TYPE(data_bias[idst]);
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[c] += FLOAT_TYPE(data_fuse0[idst]);
|
||||
}
|
||||
dst[idst] = temp[c];
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
temp[c] += FLOAT_TYPE(data_fuse1[idst]);
|
||||
}
|
||||
data_d[idst] = temp[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
|
||||
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,21 +345,22 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
|
||||
// Repack 2x4 quants into one int
|
||||
// Add the 3rd bit instead of subtracting it to allow packing the quants
|
||||
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
// vec4 for unpack8 used due to #12147
|
||||
const i8vec2 vals00 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
const i8vec2 vals01 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
const i8vec2 vals10 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
const i8vec2 vals11 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
|
||||
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
|
||||
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
|
||||
}
|
||||
@@ -516,15 +517,15 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
|
||||
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals00 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))).xy |
|
||||
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
|
||||
const i8vec2 vals01 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))).xy |
|
||||
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
|
||||
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
|
||||
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
data_d[i] = D_TYPE(-float(data_a[i]));
|
||||
}
|
||||
@@ -802,6 +802,9 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -827,6 +830,8 @@ void process_shaders() {
|
||||
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -835,6 +840,8 @@ void process_shaders() {
|
||||
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
|
||||
+154
-5
@@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"COS",
|
||||
"SUM",
|
||||
"SUM_ROWS",
|
||||
"CUMSUM",
|
||||
"MEAN",
|
||||
"ARGMAX",
|
||||
"COUNT_EQUAL",
|
||||
@@ -990,6 +991,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"TIMESTEP_EMBEDDING",
|
||||
"ARGSORT",
|
||||
"LEAKY_RELU",
|
||||
"TRI",
|
||||
"FILL",
|
||||
|
||||
"FLASH_ATTN_EXT",
|
||||
"FLASH_ATTN_BACK",
|
||||
@@ -1002,6 +1005,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"RWKV_WKV6",
|
||||
"GATED_LINEAR_ATTN",
|
||||
"RWKV_WKV7",
|
||||
"SOLVE_TRI",
|
||||
|
||||
"UNARY",
|
||||
|
||||
@@ -1019,7 +1023,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
|
||||
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1039,6 +1043,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cos(x)",
|
||||
"Σx",
|
||||
"Σx_k",
|
||||
"cumsum(x)",
|
||||
"Σx/n",
|
||||
"argmax(x)",
|
||||
"count_equal(x)",
|
||||
@@ -1094,6 +1099,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"timestep_embedding(timesteps, dim, max_period)",
|
||||
"argsort(x)",
|
||||
"leaky_relu(x)",
|
||||
"tri(x)",
|
||||
"fill(x, c)",
|
||||
|
||||
"flash_attn_ext(x)",
|
||||
"flash_attn_back(x)",
|
||||
@@ -1106,6 +1113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
"gated_linear_attn(k, v, q, gate, s)",
|
||||
"rwkv_wkv7(r, w, k, v, a, b, s)",
|
||||
"A X = B, A triangular, solve X",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
@@ -1123,7 +1131,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
|
||||
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -1142,6 +1150,8 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"HARDSWISH",
|
||||
"HARDSIGMOID",
|
||||
"EXP",
|
||||
"EXPM1",
|
||||
"SOFTPLUS",
|
||||
"GELU_ERF",
|
||||
"XIELU",
|
||||
"FLOOR",
|
||||
@@ -1150,7 +1160,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"TRUNC",
|
||||
};
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 20, "GGML_UNARY_OP_COUNT != 20");
|
||||
static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
|
||||
|
||||
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
|
||||
"REGLU",
|
||||
@@ -2258,6 +2268,30 @@ struct ggml_tensor * ggml_log_inplace(
|
||||
return ggml_log_impl(ctx, a, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_expm1(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_expm1_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_softplus(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_softplus_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);
|
||||
}
|
||||
|
||||
// ggml_sin
|
||||
|
||||
static struct ggml_tensor * ggml_sin_impl(
|
||||
@@ -2341,6 +2375,21 @@ struct ggml_tensor * ggml_sum_rows(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_cumsum
|
||||
|
||||
struct ggml_tensor * ggml_cumsum(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_CUMSUM;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_mean
|
||||
|
||||
struct ggml_tensor * ggml_mean(
|
||||
@@ -2668,8 +2717,8 @@ struct ggml_tensor * ggml_xielu(
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
|
||||
ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
|
||||
ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
|
||||
ggml_set_op_params_f32(result, 1, beta + ggml_compute_softplus_f32(alpha_n));
|
||||
ggml_set_op_params_f32(result, 2, ggml_compute_softplus_f32(alpha_p));
|
||||
ggml_set_op_params_f32(result, 3, beta);
|
||||
ggml_set_op_params_f32(result, 4, eps);
|
||||
|
||||
@@ -5028,6 +5077,61 @@ struct ggml_tensor * ggml_timestep_embedding(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_tri
|
||||
|
||||
struct ggml_tensor * ggml_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_tri_type type) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(a->ne[0] == a->ne[1]);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, type);
|
||||
|
||||
result->op = GGML_OP_TRI;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_fill
|
||||
|
||||
static struct ggml_tensor * ggml_fill_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_f32(result, 0, c);
|
||||
|
||||
result->op = GGML_OP_FILL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_fill(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c) {
|
||||
return ggml_fill_impl(ctx, a, c, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_fill_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float c) {
|
||||
return ggml_fill_impl(ctx, a, c, true);
|
||||
}
|
||||
|
||||
// ggml_argsort
|
||||
|
||||
struct ggml_tensor * ggml_argsort(
|
||||
@@ -5882,6 +5986,41 @@ struct ggml_tensor * ggml_opt_step_sgd(
|
||||
return result;
|
||||
}
|
||||
|
||||
// solve_tri
|
||||
|
||||
struct ggml_tensor * ggml_solve_tri(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool left,
|
||||
bool lower,
|
||||
bool uni) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
|
||||
// A must be square and lower diagonal
|
||||
GGML_ASSERT(a->ne[0] == a->ne[1]);
|
||||
// B must have same outer dimension as A
|
||||
GGML_ASSERT(a->ne[1] == b->ne[1]);
|
||||
|
||||
// batch dimensions must be equal
|
||||
GGML_ASSERT(a->ne[2] == b->ne[2]);
|
||||
GGML_ASSERT(a->ne[3] == b->ne[3]);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
GGML_ASSERT(ggml_is_contiguous(b));
|
||||
|
||||
GGML_ASSERT(lower && left && !uni); // TODO: support other variants
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
|
||||
|
||||
result->op = GGML_OP_SOLVE_TRI;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ggml_hash_set ggml_hash_set_new(size_t size) {
|
||||
@@ -6454,6 +6593,16 @@ static void ggml_compute_backward(
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXPM1: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0)));
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sigmoid(ctx, src0)));
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
|
||||
__func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
|
||||
|
||||
@@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum):
|
||||
BAILINGMOE2 = auto()
|
||||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
AFMOE = auto()
|
||||
ERNIE4_5 = auto()
|
||||
ERNIE4_5_MOE = auto()
|
||||
HUNYUAN_MOE = auto()
|
||||
@@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_POST_NORM = auto()
|
||||
ATTN_ROT_EMBD = auto()
|
||||
ATTN_SINKS = auto()
|
||||
ATTN_GATE = auto()
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
@@ -776,6 +778,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
|
||||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
MODEL_ARCH.AFMOE: "afmoe",
|
||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
|
||||
MODEL_ARCH.FALCON_H1: "falcon-h1",
|
||||
@@ -828,6 +831,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
|
||||
MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
|
||||
@@ -2693,6 +2697,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.AFMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_GATE,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.ERNIE4_5: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@@ -314,6 +314,10 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.sinks", # openai-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_GATE: (
|
||||
"model.layers.{bid}.self_attn.gate_proj", # afmoe
|
||||
),
|
||||
|
||||
# Feed-forward norm
|
||||
MODEL_TENSOR.FFN_NORM: (
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||
@@ -340,11 +344,12 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.feedforward_layernorm", # apertus
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
# Pre feed-forward norm
|
||||
MODEL_TENSOR.FFN_PRE_NORM: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
"layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
|
||||
"model.layers.{bid}.pre_ff_layernorm.weight",
|
||||
"model.layers.{bid}.pre_mlp_layernorm", # afmoe
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@@ -370,6 +375,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
||||
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
||||
"model.layers.{bid}.mlp.router.gate", # afmoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@@ -380,6 +386,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
||||
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
||||
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
|
||||
"model.layers.{bid}.mlp.expert_bias", # afmoe
|
||||
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
|
||||
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
|
||||
),
|
||||
|
||||
@@ -35,6 +35,7 @@ add_library(llama
|
||||
unicode-data.cpp
|
||||
unicode.cpp
|
||||
unicode.h
|
||||
models/afmoe.cpp
|
||||
models/apertus.cpp
|
||||
models/arcee.cpp
|
||||
models/arctic.cpp
|
||||
|
||||
@@ -90,6 +90,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_AFMOE, "afmoe" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
@@ -333,6 +334,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_AFMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA4,
|
||||
{
|
||||
@@ -2444,6 +2475,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
||||
@@ -94,6 +94,7 @@ enum llm_arch {
|
||||
LLM_ARCH_BAILINGMOE2,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_AFMOE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
@@ -312,6 +313,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_ATTN_GATE,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
|
||||
@@ -84,6 +84,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_15B: return "15B";
|
||||
case LLM_TYPE_16B: return "16B";
|
||||
case LLM_TYPE_20B: return "20B";
|
||||
case LLM_TYPE_26B: return "26B";
|
||||
case LLM_TYPE_27B: return "27B";
|
||||
case LLM_TYPE_30B: return "30B";
|
||||
case LLM_TYPE_32B: return "32B";
|
||||
@@ -695,6 +696,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
|
||||
// Set up interleaved sliding window attention (ISWA)
|
||||
// Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4)
|
||||
if (hparams.n_swa > 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.set_swa_pattern(4);
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
}
|
||||
|
||||
// Default to sigmoid if not set
|
||||
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 56: type = LLM_TYPE_6B; break;
|
||||
case 32: type = LLM_TYPE_26B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DECI:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@@ -5749,6 +5781,71 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
const int64_t n_expert_shared = hparams.n_expert_shared;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
// dual attention normalization
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// attention projections
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
// Q/K normalization
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
// attention gating
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
|
||||
// dual ffn normalization
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
|
||||
// MoE layers
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
|
||||
|
||||
// grouped expert weights
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
|
||||
// shared expert
|
||||
if (n_expert_shared > 0) {
|
||||
const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
}
|
||||
} else {
|
||||
// Dense layers
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_ERNIE4_5_MOE:
|
||||
{
|
||||
@@ -7243,6 +7340,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_arcee>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_afmoe>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
{
|
||||
llm = std::make_unique<llm_build_ernie4_5>(*this, params);
|
||||
@@ -7528,6 +7629,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_MINIMAX_M2:
|
||||
case LLM_ARCH_COGVLM:
|
||||
case LLM_ARCH_PANGU_EMBED:
|
||||
case LLM_ARCH_AFMOE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
@@ -76,6 +76,7 @@ enum llm_type {
|
||||
LLM_TYPE_15B,
|
||||
LLM_TYPE_16B,
|
||||
LLM_TYPE_20B,
|
||||
LLM_TYPE_26B,
|
||||
LLM_TYPE_27B,
|
||||
LLM_TYPE_30B,
|
||||
LLM_TYPE_32B,
|
||||
@@ -234,6 +235,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * wk_enc = nullptr;
|
||||
struct ggml_tensor * wv_enc = nullptr;
|
||||
struct ggml_tensor * wo_enc = nullptr;
|
||||
struct ggml_tensor * wqkv_gate = nullptr;
|
||||
|
||||
// attention bias
|
||||
struct ggml_tensor * bq = nullptr;
|
||||
|
||||
+10
-5
@@ -4,6 +4,7 @@
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
@@ -1625,10 +1626,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
auto * ctx = new llama_sampler_grammar;
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
std::string trigger_pattern;
|
||||
llama_grammar * grammar = nullptr;
|
||||
// TODO: remove trigger_words support.
|
||||
if (trigger_words != nullptr && num_trigger_words > 0) {
|
||||
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
|
||||
std::string trigger_pattern("[\\s\\S]*?(");
|
||||
trigger_pattern = "[\\s\\S]*?(";
|
||||
for (size_t i = 0; i < num_trigger_words; ++i) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
if (i > 0) {
|
||||
@@ -1637,15 +1640,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
|
||||
}
|
||||
trigger_pattern += ")[\\s\\S]*";
|
||||
const auto * trigger_pattern_c = trigger_pattern.c_str();
|
||||
trigger_patterns = &trigger_pattern_c;
|
||||
num_trigger_patterns = 1;
|
||||
|
||||
std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
|
||||
} else {
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
|
||||
}
|
||||
*ctx = {
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ grammar_str,
|
||||
/* .grammar_root = */ grammar_root,
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
/* .grammar = */ grammar,
|
||||
};
|
||||
if (!ctx->grammar) {
|
||||
delete ctx;
|
||||
|
||||
@@ -443,6 +443,17 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_AFMOE:
|
||||
regex_exprs = {
|
||||
// Digit handling - uses custom implementation in unicode.cpp
|
||||
// Groups digits with leading 1-2 based on total length modulo 3
|
||||
"\\p{AFMoE_digits}",
|
||||
// CJK and Asian scripts (using direct Unicode literals)
|
||||
"[一-鿿㐀-䶿豈--ゟ゠-ヿ・-゚⼀-เ--ក-က-႟ꩠ-ꩿꧠ-가-ᄀ-ᇿ]+",
|
||||
// Main BPE pattern
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@@ -1993,6 +2004,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "grok-2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "afmoe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_AFMOE;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "minimax-m2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
|
||||
|
||||
@@ -50,6 +50,7 @@ enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
|
||||
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
|
||||
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
|
||||
};
|
||||
|
||||
struct LLM_KV;
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// MuP scaling: embeddings * sqrt(hidden_size)
|
||||
// mup_enabled = true, hidden_size = 1024, scale = 32.0
|
||||
inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
|
||||
cb(inpL, "inp_embd_scaled", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// dual attention normalization (pre)
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * attn_inp = cur; // save input for gate computation
|
||||
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// compute gate from input
|
||||
ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
|
||||
cb(gate, "attn_gate_proj", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// Q/K normalization
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
// RoPE only for sliding_attention layers
|
||||
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
|
||||
((il + 1) % hparams.n_no_rope_layer_step) != 0;
|
||||
if (use_rope) {
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur_rope", il);
|
||||
}
|
||||
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
NULL, NULL, // wo will be applied after gating
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
|
||||
// attention gating: attn_out * sigmoid(gate) BEFORE o_proj
|
||||
gate = ggml_sigmoid(ctx0, gate);
|
||||
cb(gate, "attn_gate_sig", il);
|
||||
cur = ggml_mul(ctx0, cur, gate);
|
||||
cb(cur, "attn_gated", il);
|
||||
|
||||
// now apply output projection
|
||||
cur = build_lora_mm(model.layers[il].wo, cur);
|
||||
cb(cur, "attn_o_proj", il);
|
||||
}
|
||||
|
||||
// dual attention normalization (post)
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].attn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// dual ffn normalization (pre)
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// MoE or dense FFN
|
||||
if ((uint32_t)il >= hparams.n_layer_dense_lead) {
|
||||
// MoE layer with sigmoid routing, normalization, and scaling
|
||||
ggml_tensor * moe_out = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU,
|
||||
hparams.expert_weights_norm, // norm_w (route_norm=True)
|
||||
hparams.expert_weights_scale, // scale_w
|
||||
hparams.expert_weights_scale, // w_scale (route_scale=2.826)
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
// shared expert
|
||||
if (hparams.n_expert_shared > 0) {
|
||||
ggml_tensor * ffn_shexp = build_ffn(cur,
|
||||
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(ffn_shexp, "ffn_shexp", il);
|
||||
|
||||
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = moe_out;
|
||||
}
|
||||
} else {
|
||||
// dense layer
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
// dual ffn normalization (post)
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].ffn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
@@ -57,6 +57,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
int il) const;
|
||||
};
|
||||
|
||||
struct llm_build_afmoe : public llm_graph_context {
|
||||
llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_apertus : public llm_graph_context {
|
||||
llm_build_apertus(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
@@ -729,6 +729,80 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// AFMOE digit handling: splits digits with leading 1-2 based on total length modulo 3
|
||||
static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
bpe_offsets.reserve(offsets.size());
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
const size_t offset_ini = start;
|
||||
const size_t offset_end = start + offset;
|
||||
assert(offset_end <= cpts.size());
|
||||
start = offset_end;
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
auto _add_token = [&] (const size_t end) -> size_t {
|
||||
assert(_prev_end <= end && end <= offset_end);
|
||||
size_t len = end - _prev_end;
|
||||
if (len > 0) {
|
||||
bpe_offsets.push_back(len);
|
||||
}
|
||||
_prev_end = end;
|
||||
return len;
|
||||
};
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; ) {
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// Handle digit sequences with special splitting logic
|
||||
if (flags.is_number) {
|
||||
size_t digit_start = pos;
|
||||
size_t digit_count = 0;
|
||||
|
||||
// Count consecutive digits
|
||||
while (_get_flags(pos).is_number && pos < offset_end) {
|
||||
digit_count++;
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Split based on total length modulo 3
|
||||
size_t remainder = digit_count % 3;
|
||||
size_t current = digit_start;
|
||||
|
||||
// Emit leading 1-2 digits if needed
|
||||
if (remainder > 0) {
|
||||
_add_token(current + remainder);
|
||||
current += remainder;
|
||||
}
|
||||
|
||||
// Emit groups of 3
|
||||
while (current < digit_start + digit_count) {
|
||||
_add_token(current + 3);
|
||||
current += 3;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// For non-digits, just move forward
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Add any remaining content
|
||||
if (_prev_end < offset_end) {
|
||||
_add_token(offset_end);
|
||||
}
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
|
||||
@@ -742,6 +816,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
|
||||
} else if (regex_expr == "\\p{Han}+") {
|
||||
// K2's first pattern - handle all K2 patterns together
|
||||
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
|
||||
} else if (regex_expr == "\\p{AFMoE_digits}") {
|
||||
// AFMOE digit pattern - use custom implementation for proper splitting
|
||||
bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
|
||||
+262
-8
@@ -175,6 +175,38 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t));
|
||||
}
|
||||
|
||||
// generate a lower triangular matrix
|
||||
static void init_tensor_tril(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
||||
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(tensor->ne[0] == tensor->ne[1]);
|
||||
|
||||
GGML_TENSOR_LOCALS(int32_t, ne, tensor, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nb, tensor, nb);
|
||||
|
||||
std::vector<float> data_f32(ne0*ne1*ne2*ne3);
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<float> dis(min, max);
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne2; i2++) {
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0++) {
|
||||
int64_t idx = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3) / sizeof(float);
|
||||
if (i0 <= i1) {
|
||||
data_f32[idx] = dis(gen);
|
||||
} else {
|
||||
data_f32[idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(tensor, data_f32.data(), 0, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
static std::vector<float> tensor_to_float(const ggml_tensor * t) {
|
||||
std::vector<float> tv;
|
||||
tv.reserve(ggml_nelements(t));
|
||||
@@ -1804,7 +1836,8 @@ struct test_unary : public test_case {
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||
|
||||
op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU;
|
||||
op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU ||
|
||||
op == GGML_UNARY_OP_EXPM1 || op == GGML_UNARY_OP_SOFTPLUS;
|
||||
|
||||
ggml_tensor * a;
|
||||
if (v & 1) {
|
||||
@@ -2779,7 +2812,7 @@ struct test_bin_bcast : public test_case {
|
||||
const std::array<int, 4> nr;
|
||||
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
bool run_whole_graph() override { return nf > 1; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, nr, nf);
|
||||
@@ -4969,17 +5002,19 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
const bool b; // broadcast b matrix (only for use_id)
|
||||
const bool with_bias;
|
||||
const bool with_gate;
|
||||
std::array<int64_t, 2> batch_dims;
|
||||
|
||||
test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
|
||||
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
|
||||
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
|
||||
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,
|
||||
std::array<int64_t, 2> batch_dims = {4, 2})
|
||||
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {
|
||||
if (use_id) {
|
||||
GGML_ASSERT(n_used <= n_mats);
|
||||
}
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
|
||||
return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);
|
||||
}
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
@@ -5005,8 +5040,8 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
if (!use_id) {
|
||||
const int channels = 4;
|
||||
const int samples = 2;
|
||||
const int channels = batch_dims[0];
|
||||
const int samples = batch_dims[1];
|
||||
std::array<int64_t, 4> ne = { k, m, channels, samples };
|
||||
std::array<int64_t, 4> ne0 = { k, n, channels, samples };
|
||||
|
||||
@@ -5029,6 +5064,11 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
|
||||
std::array<int64_t, 4> bias2_ne = { out->ne[0], 1, channels, samples };
|
||||
ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());
|
||||
out = ggml_add(ctx, out, bias2);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
} else {
|
||||
@@ -5056,6 +5096,11 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
|
||||
std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };
|
||||
ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());
|
||||
out = ggml_mul(ctx, out, scale);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
}
|
||||
@@ -5395,6 +5440,7 @@ struct test_pad : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_PAD (with extension)
|
||||
struct test_pad_ext : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne_a;
|
||||
@@ -5802,6 +5848,7 @@ struct test_opt_step_adamw : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_OPT_STEP_SGD
|
||||
struct test_opt_step_sgd : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
@@ -5841,6 +5888,170 @@ struct test_opt_step_sgd : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CUMSUM
|
||||
struct test_cumsum : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR2(type, ne); }
|
||||
|
||||
test_cumsum(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_cumsum(ctx, a);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_XIELU
|
||||
struct test_xielu : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR2(type, ne); }
|
||||
|
||||
test_xielu(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
float alpha_n = 4.0f;
|
||||
float alpha_p = 20.0f;
|
||||
float beta = 0.5f;
|
||||
float eps = 0.0000001f;
|
||||
|
||||
ggml_tensor * out = ggml_xielu(ctx, a, alpha_n, alpha_p, beta, eps);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_TRI
|
||||
struct test_tri : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const ggml_tri_type tri_type;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR3(type, ne, tri_type); }
|
||||
|
||||
test_tri(ggml_tri_type tri_type, ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
|
||||
: type(type), ne(ne), tri_type(tri_type) {
|
||||
GGML_ASSERT(ne[0] == ne[1]);
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_tri(ctx, a, tri_type);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_FILL
|
||||
struct test_fill : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
float c;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR3(type, ne, c); }
|
||||
|
||||
test_fill(float c, ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
|
||||
: type(type), ne(ne), c(c) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * out = ggml_fill(ctx, a, c);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SOLVE_TRI
|
||||
struct test_solve_tri : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne_lhs;
|
||||
const std::array<int64_t, 4> ne_rhs;
|
||||
|
||||
std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
|
||||
|
||||
test_solve_tri(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
|
||||
std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
|
||||
)
|
||||
: type(type), ne_lhs(ne_lhs), ne_rhs(ne_rhs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne_lhs[0], ne_lhs[1], ne_lhs[2], ne_lhs[3]);
|
||||
ggml_set_param(a);
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne_rhs[0], ne_rhs[1], ne_rhs[2], ne_rhs[3]);
|
||||
ggml_set_param(b);
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
ggml_tensor * out = ggml_solve_tri(ctx, a, b, true, true, false);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (strcmp(t->name, "a") == 0) {
|
||||
// note: avoid zeros in the diagonal
|
||||
init_tensor_tril(t, 0.1, 1.0f);
|
||||
} else {
|
||||
init_tensor_uniform(t, -1.0f, 1.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
enum llm_norm_type {
|
||||
LLM_NORM,
|
||||
LLM_NORM_RMS,
|
||||
@@ -6282,6 +6493,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
for (int v : {0, 1}) {
|
||||
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
||||
if (op == GGML_UNARY_OP_XIELU) {
|
||||
continue; // need extra params, separate test
|
||||
}
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
|
||||
}
|
||||
@@ -7290,8 +7504,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||
}
|
||||
|
||||
@@ -7340,6 +7559,39 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 512, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
|
||||
|
||||
test_cases.emplace_back(new test_xielu());
|
||||
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG));
|
||||
|
||||
test_cases.emplace_back(new test_fill(0.0f));
|
||||
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
|
||||
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
|
||||
|
||||
test_cases.emplace_back(new test_solve_tri());
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
|
||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
|
||||
|
||||
for (bool v : {false, true}) {
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
|
||||
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
|
||||
@@ -7418,6 +7670,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
|
||||
use_id, 16, 8, b, with_bias, with_gate));
|
||||
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
|
||||
use_id, 16, 8, b, with_bias, with_gate, {1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+5
-12
@@ -224,7 +224,6 @@ static void clip_log_callback_default(enum ggml_log_level level, const char * te
|
||||
}
|
||||
|
||||
struct clip_logger_state {
|
||||
ggml_log_level verbosity_thold;
|
||||
ggml_log_callback log_callback;
|
||||
void * log_callback_user_data;
|
||||
};
|
||||
@@ -258,17 +257,11 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
#define LOG_TMPL(level, ...) \
|
||||
do { \
|
||||
if ((level) >= g_logger_state.verbosity_thold) { \
|
||||
clip_log_internal((level), __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
|
||||
#define LOG_INF(...) clip_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) clip_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) clip_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_DBG(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) clip_log_internal(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
|
||||
|
||||
//
|
||||
// cpp wrappers
|
||||
|
||||
+1
-3
@@ -24,8 +24,7 @@
|
||||
#include <array>
|
||||
#include <functional>
|
||||
|
||||
// TODO: allow to pass callback from user code
|
||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
|
||||
|
||||
enum ffn_op_type {
|
||||
FFN_GELU,
|
||||
@@ -3507,7 +3506,6 @@ struct clip_model_loader {
|
||||
};
|
||||
|
||||
struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
|
||||
g_logger_state.verbosity_thold = ctx_params.verbosity;
|
||||
clip_ctx * ctx_vision = nullptr;
|
||||
clip_ctx * ctx_audio = nullptr;
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ enum clip_flash_attn_type {
|
||||
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
enum ggml_log_level verbosity;
|
||||
enum clip_flash_attn_type flash_attn_type;
|
||||
int image_min_tokens;
|
||||
int image_max_tokens;
|
||||
|
||||
@@ -135,7 +135,6 @@ struct mtmd_cli_context {
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = true;
|
||||
mparams.n_threads = params.cpuparams.n_threads;
|
||||
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params.flash_attn_type;
|
||||
mparams.image_min_tokens = params.image_min_tokens;
|
||||
mparams.image_max_tokens = params.image_max_tokens;
|
||||
@@ -277,6 +276,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
common_init();
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
if (params.mmproj.path.empty()) {
|
||||
show_additional_info(argc, argv);
|
||||
@@ -285,7 +285,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
mtmd_cli_context ctx(params);
|
||||
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
|
||||
LOG_INF("%s: loading model: %s\n", __func__, params.model.path.c_str());
|
||||
|
||||
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
||||
|
||||
|
||||
@@ -32,8 +32,65 @@
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb/stb_image.h"
|
||||
|
||||
#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
|
||||
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)
|
||||
//
|
||||
// internal logging functions
|
||||
//
|
||||
|
||||
struct mtmd_helper_logger {
|
||||
ggml_log_callback default_callback = [](ggml_log_level level, const char * text, void * user_data) {
|
||||
(void) level;
|
||||
(void) user_data;
|
||||
fputs(text, stderr);
|
||||
fflush(stderr);
|
||||
};
|
||||
|
||||
ggml_log_callback log_callback = default_callback;
|
||||
void * log_callback_user_data;
|
||||
|
||||
void log_v(enum ggml_log_level level, const char * format, va_list args) {
|
||||
if (format == NULL) {
|
||||
return;
|
||||
}
|
||||
va_list args_copy;
|
||||
va_copy(args_copy, args);
|
||||
char buffer[128];
|
||||
int len = vsnprintf(buffer, 128, format, args);
|
||||
if (len < 128) {
|
||||
log_callback(level, buffer, log_callback_user_data);
|
||||
} else {
|
||||
char * buffer2 = (char *) calloc(len + 1, sizeof(char));
|
||||
vsnprintf(buffer2, len + 1, format, args_copy);
|
||||
buffer2[len] = 0;
|
||||
log_callback(level, buffer2, log_callback_user_data);
|
||||
free(buffer2);
|
||||
}
|
||||
va_end(args_copy);
|
||||
}
|
||||
|
||||
void log(enum ggml_log_level level, const char * format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
log_v(level, format, args);
|
||||
va_end(args);
|
||||
}
|
||||
} g_logger;
|
||||
|
||||
#define LOG_INF(...) g_logger.log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) g_logger.log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) g_logger.log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
|
||||
void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
if (log_callback == nullptr) {
|
||||
log_callback = g_logger.default_callback;
|
||||
}
|
||||
g_logger.log_callback = log_callback;
|
||||
g_logger.log_callback_user_data = user_data;
|
||||
mtmd_log_set(log_callback, user_data);
|
||||
}
|
||||
|
||||
//
|
||||
// helper functions
|
||||
//
|
||||
|
||||
size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
|
||||
size_t n_tokens = 0;
|
||||
@@ -325,7 +382,7 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
|
||||
llama_pos * new_n_past) {
|
||||
size_t n_chunks = mtmd_input_chunks_size(chunks);
|
||||
if (n_chunks == 0) {
|
||||
LOG_ERR("no chunks to eval\n");
|
||||
LOG_WRN("no chunks to eval\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,11 @@ extern "C" {
|
||||
// BREAKING CHANGES are expected.
|
||||
//
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
// Note: this also call mtmd_log_set() internally
|
||||
MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a file
|
||||
// it calls mtmd_helper_bitmap_init_from_buf() internally
|
||||
// returns nullptr on failure
|
||||
|
||||
+5
-2
@@ -105,7 +105,6 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
/* use_gpu */ true,
|
||||
/* print_timings */ true,
|
||||
/* n_threads */ 4,
|
||||
/* verbosity */ GGML_LOG_LEVEL_INFO,
|
||||
/* image_marker */ MTMD_DEFAULT_IMAGE_MARKER,
|
||||
/* media_marker */ mtmd_default_marker(),
|
||||
/* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
|
||||
@@ -175,7 +174,6 @@ struct mtmd_context {
|
||||
|
||||
clip_context_params ctx_clip_params {
|
||||
/* use_gpu */ ctx_params.use_gpu,
|
||||
/* verbosity */ ctx_params.verbosity,
|
||||
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
|
||||
/* image_min_tokens */ ctx_params.image_min_tokens,
|
||||
/* image_max_tokens */ ctx_params.image_max_tokens,
|
||||
@@ -1096,3 +1094,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
void mtmd_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
g_logger_state.log_callback = log_callback ? log_callback : clip_log_callback_default;
|
||||
g_logger_state.log_callback_user_data = user_data;
|
||||
}
|
||||
|
||||
+4
-1
@@ -79,7 +79,6 @@ struct mtmd_context_params {
|
||||
bool use_gpu;
|
||||
bool print_timings;
|
||||
int n_threads;
|
||||
enum ggml_log_level verbosity;
|
||||
const char * image_marker; // deprecated, use media_marker instead
|
||||
const char * media_marker;
|
||||
enum llama_flash_attn_type flash_attn_type;
|
||||
@@ -215,6 +214,10 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
|
||||
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
/////////////////////////////////////////
|
||||
|
||||
// test function, to be used in test-mtmd-c-api.c
|
||||
|
||||
Binary file not shown.
+85
-79
@@ -1686,14 +1686,13 @@ struct server_slot {
|
||||
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
|
||||
}
|
||||
|
||||
void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
||||
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
||||
bool res = prompt_cache.load(prompt, tokens, ctx, id);
|
||||
if (!res) {
|
||||
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
|
||||
prompt.tokens.clear();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<common_adapter_lora_info> lora;
|
||||
@@ -2339,7 +2338,6 @@ struct server_context {
|
||||
|
||||
llama_batch batch {};
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
@@ -2454,11 +2452,12 @@ struct server_context {
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params_base.flash_attn_type;
|
||||
mparams.image_min_tokens = params_base.image_min_tokens;
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
@@ -2701,7 +2700,10 @@ struct server_context {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
ret->prompt_save(*prompt_cache);
|
||||
ret->prompt_load(*prompt_cache, task.tokens);
|
||||
|
||||
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
|
||||
clear_slot(*ret);
|
||||
}
|
||||
|
||||
prompt_cache->update();
|
||||
|
||||
@@ -2712,12 +2714,21 @@ struct server_context {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// return true if at least one slot has been purged
|
||||
void clear_slot(server_slot & slot) const {
|
||||
GGML_ASSERT(!slot.is_processing());
|
||||
|
||||
SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
// return true if at least one slot has been cleared
|
||||
// TODO: improve logic
|
||||
// - smarter decision which slot to purge (LRU or longest prompt?)
|
||||
// - smarter decision which slot to clear (LRU or longest prompt?)
|
||||
// - move slot to level 2 cache instead of removing?
|
||||
// - instead of purging, try to store and resume later?
|
||||
bool try_purge_idle_slots() {
|
||||
bool try_clear_idle_slots() {
|
||||
bool res = false;
|
||||
|
||||
if (!params_base.kv_unified) {
|
||||
@@ -2732,12 +2743,11 @@ struct server_context {
|
||||
if (slot.prompt.n_tokens() > 0) {
|
||||
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
clear_slot(slot);
|
||||
|
||||
res = true;
|
||||
|
||||
// purge slots one by one
|
||||
// clear slots one by one
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -2847,14 +2857,6 @@ struct server_context {
|
||||
return true;
|
||||
}
|
||||
|
||||
void kv_cache_clear() {
|
||||
SRV_DBG("%s", "clearing KV cache\n");
|
||||
|
||||
// clear the entire KV cache
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
clean_kv_cache = false;
|
||||
}
|
||||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = result.text_to_send;
|
||||
@@ -3442,8 +3444,8 @@ struct server_context {
|
||||
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->prompt.tokens.size();
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
|
||||
slot->prompt.tokens.clear();
|
||||
|
||||
clear_slot(*slot);
|
||||
|
||||
auto res = std::make_unique<server_task_result_slot_erase>();
|
||||
res->id = task.id;
|
||||
@@ -3476,9 +3478,6 @@ struct server_context {
|
||||
|
||||
if (all_idle) {
|
||||
SRV_INF("%s", "all slots are idle\n");
|
||||
if (clean_kv_cache) {
|
||||
kv_cache_clear();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -3591,13 +3590,13 @@ struct server_context {
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if we can batch this slot with the previous one
|
||||
if (slot.is_processing()) {
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
} else if (!slot_batched->can_batch_with(slot)) {
|
||||
continue;
|
||||
}
|
||||
if (slot_batched && !slot_batched->can_batch_with(slot)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// this slot still has a prompt to be processed
|
||||
@@ -3872,12 +3871,11 @@ struct server_context {
|
||||
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
|
||||
clear_slot(slot);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_prompt_tokens_cache = 0;
|
||||
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
// check if we should process the image
|
||||
@@ -4028,6 +4026,10 @@ struct server_context {
|
||||
}
|
||||
}
|
||||
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
}
|
||||
|
||||
if (batch.n_tokens >= n_batch) {
|
||||
break;
|
||||
}
|
||||
@@ -4103,6 +4105,10 @@ struct server_context {
|
||||
if (slot.is_processing()) {
|
||||
send_error(slot, err);
|
||||
slot.release();
|
||||
|
||||
// note: it's complicated to keep track of how much of the current batch has been
|
||||
// processed before the error occurred, so we simply clear the entire context
|
||||
clear_slot(slot);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4111,7 +4117,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
if (!try_purge_idle_slots()) {
|
||||
if (!try_clear_idle_slots()) {
|
||||
n_batch /= 2;
|
||||
}
|
||||
|
||||
@@ -4431,7 +4437,7 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
|
||||
SRV_DBG("response: %s\n", res.body.c_str());
|
||||
}
|
||||
|
||||
static void res_error(httplib::Response & res, const json & error_data) {
|
||||
static void res_err(httplib::Response & res, const json & error_data) {
|
||||
json final_response {{"error", error_data}};
|
||||
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
@@ -4524,7 +4530,7 @@ int main(int argc, char ** argv) {
|
||||
try {
|
||||
json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
|
||||
LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
|
||||
res_error(res, formatted_error);
|
||||
res_err(res, formatted_error);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
|
||||
}
|
||||
@@ -4532,9 +4538,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
|
||||
if (res.status == 404) {
|
||||
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
|
||||
}
|
||||
// for other error codes, we skip processing here because it's already done by res_error()
|
||||
// for other error codes, we skip processing here because it's already done by res_err()
|
||||
});
|
||||
|
||||
// set timeouts and change hostname and port
|
||||
@@ -4591,7 +4597,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
res_err(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
|
||||
LOG_WRN("Unauthorized: Invalid API Key\n");
|
||||
|
||||
@@ -4609,7 +4615,7 @@ int main(int argc, char ** argv) {
|
||||
// allow the models endpoint to be accessed during loading
|
||||
return true;
|
||||
} else {
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -4648,7 +4654,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!params.endpoint_slots) {
|
||||
res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4666,7 +4672,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4677,7 +4683,7 @@ int main(int argc, char ** argv) {
|
||||
// optionally return "fail_on_no_slot" error
|
||||
if (req.has_param("fail_on_no_slot")) {
|
||||
if (res_task->n_idle_slots == 0) {
|
||||
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
res_err(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -4687,7 +4693,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_metrics) {
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4705,7 +4711,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4790,7 +4796,7 @@ int main(int argc, char ** argv) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -4811,7 +4817,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4822,7 +4828,7 @@ int main(int argc, char ** argv) {
|
||||
json request_data = json::parse(req.body);
|
||||
std::string filename = request_data.at("filename");
|
||||
if (!fs_validate_filename(filename)) {
|
||||
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
std::string filepath = params.slot_save_path + filename;
|
||||
@@ -4843,7 +4849,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4866,7 +4872,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4876,7 +4882,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
if (params.slot_save_path.empty()) {
|
||||
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4886,7 +4892,7 @@ int main(int argc, char ** argv) {
|
||||
try {
|
||||
id_slot = std::stoi(id_slot_str);
|
||||
} catch (const std::exception &) {
|
||||
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4899,7 +4905,7 @@ int main(int argc, char ** argv) {
|
||||
} else if (action == "erase") {
|
||||
handle_slots_erase(req, res, id_slot);
|
||||
} else {
|
||||
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4947,7 +4953,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!ctx_server.params_base.endpoint_props) {
|
||||
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5044,7 +5050,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
rd->post_tasks(std::move(tasks));
|
||||
} catch (const std::exception & e) {
|
||||
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5056,7 +5062,7 @@ int main(int argc, char ** argv) {
|
||||
if (all_results.is_terminated) {
|
||||
return; // connection is closed
|
||||
} else if (all_results.error) {
|
||||
res_error(res, all_results.error->to_json());
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
} else {
|
||||
json arr = json::array();
|
||||
@@ -5076,7 +5082,7 @@ int main(int argc, char ** argv) {
|
||||
if (first_result == nullptr) {
|
||||
return; // connection is closed
|
||||
} else if (first_result->is_error()) {
|
||||
res_error(res, first_result->to_json());
|
||||
res_err(res, first_result->to_json());
|
||||
return;
|
||||
} else {
|
||||
GGML_ASSERT(
|
||||
@@ -5183,7 +5189,7 @@ int main(int argc, char ** argv) {
|
||||
err += "middle token is missing. ";
|
||||
}
|
||||
if (!err.empty()) {
|
||||
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5192,20 +5198,20 @@ int main(int argc, char ** argv) {
|
||||
// validate input
|
||||
if (data.contains("prompt") && !data.at("prompt").is_string()) {
|
||||
// prompt is optional
|
||||
res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (!data.contains("input_prefix")) {
|
||||
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (!data.contains("input_suffix")) {
|
||||
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||
}
|
||||
|
||||
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
||||
// input_extra is optional
|
||||
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5213,12 +5219,12 @@ int main(int argc, char ** argv) {
|
||||
for (const auto & chunk : input_extra) {
|
||||
// { "text": string, "filename": string }
|
||||
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
// filename is optional
|
||||
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
|
||||
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -5380,12 +5386,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
|
||||
if (!ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5399,7 +5405,7 @@ int main(int argc, char ** argv) {
|
||||
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
|
||||
prompt = body.at("content");
|
||||
} else {
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5409,7 +5415,7 @@ int main(int argc, char ** argv) {
|
||||
if (format == "base64") {
|
||||
use_base64 = true;
|
||||
} else if (format != "float") {
|
||||
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -5418,7 +5424,7 @@ int main(int argc, char ** argv) {
|
||||
for (const auto & tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -5459,7 +5465,7 @@ int main(int argc, char ** argv) {
|
||||
if (all_results.is_terminated) {
|
||||
return; // connection is closed
|
||||
} else if (all_results.error) {
|
||||
res_error(res, all_results.error->to_json());
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
} else {
|
||||
for (auto & res : all_results.results) {
|
||||
@@ -5485,7 +5491,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
|
||||
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_err(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5500,18 +5506,18 @@ int main(int argc, char ** argv) {
|
||||
if (body.count("query") == 1) {
|
||||
query = body.at("query");
|
||||
if (!query.is_string()) {
|
||||
res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> documents = json_value(body, "documents",
|
||||
json_value(body, "texts", std::vector<std::string>()));
|
||||
if (documents.empty()) {
|
||||
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5541,7 +5547,7 @@ int main(int argc, char ** argv) {
|
||||
if (all_results.is_terminated) {
|
||||
return; // connection is closed
|
||||
} else if (all_results.error) {
|
||||
res_error(res, all_results.error->to_json());
|
||||
res_err(res, all_results.error->to_json());
|
||||
return;
|
||||
} else {
|
||||
for (auto & res : all_results.results) {
|
||||
@@ -5594,7 +5600,7 @@ int main(int argc, char ** argv) {
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
const json body = json::parse(req.body);
|
||||
if (!body.is_array()) {
|
||||
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||
res_err(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -5612,7 +5618,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
res_err(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
+19
-29
@@ -1,6 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { RemoveButton } from '$lib/components/app';
|
||||
import { formatFileSize, getFileTypeLabel, getPreviewText } from '$lib/utils/file-preview';
|
||||
import { FileTypeCategory, MimeTypeText } from '$lib/enums/files';
|
||||
|
||||
@@ -66,17 +65,15 @@
|
||||
</button>
|
||||
{:else}
|
||||
<!-- Non-readonly mode (ChatForm) -->
|
||||
<div class="relative rounded-lg border border-border bg-muted p-3 {className} w-64">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="absolute top-2 right-2 h-6 w-6 bg-white/20 p-0 hover:bg-white/30"
|
||||
onclick={() => onRemove?.(id)}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
<button
|
||||
class="group relative rounded-lg border border-border bg-muted p-3 {className} {textContent
|
||||
? 'max-h-24 max-w-72'
|
||||
: 'max-w-36'} cursor-pointer text-left"
|
||||
onclick={onClick}
|
||||
>
|
||||
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<RemoveButton {id} {onRemove} />
|
||||
</div>
|
||||
|
||||
<div class="pr-8">
|
||||
<span class="mb-3 block truncate text-sm font-medium text-foreground">{name}</span>
|
||||
@@ -85,7 +82,7 @@
|
||||
<div class="relative">
|
||||
<div
|
||||
class="overflow-hidden font-mono text-xs leading-relaxed break-words whitespace-pre-wrap text-muted-foreground"
|
||||
style="max-height: 3.6em; line-height: 1.2em;"
|
||||
style="max-height: 3rem; line-height: 1.2em;"
|
||||
>
|
||||
{getPreviewText(textContent)}
|
||||
</div>
|
||||
@@ -98,11 +95,11 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
{/if}
|
||||
{:else}
|
||||
<button
|
||||
class="flex items-center gap-2 gap-3 rounded-lg border border-border bg-muted p-3 {className}"
|
||||
class="group flex items-center gap-3 rounded-lg border border-border bg-muted p-3 {className} relative"
|
||||
onclick={onClick}
|
||||
>
|
||||
<div
|
||||
@@ -112,7 +109,9 @@
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<span class="max-w-36 truncate text-sm font-medium text-foreground md:max-w-72">
|
||||
<span
|
||||
class="max-w-24 truncate text-sm font-medium text-foreground group-hover:pr-6 md:max-w-32"
|
||||
>
|
||||
{name}
|
||||
</span>
|
||||
|
||||
@@ -122,18 +121,9 @@
|
||||
</div>
|
||||
|
||||
{#if !readonly}
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 p-0"
|
||||
onclick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
<div class="absolute top-2 right-2 opacity-0 transition-opacity group-hover:opacity-100">
|
||||
<RemoveButton {id} {onRemove} />
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
+5
-14
@@ -1,6 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { RemoveButton } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
@@ -26,12 +25,12 @@
|
||||
class: className = '',
|
||||
// Default to small size for form previews
|
||||
width = 'w-auto',
|
||||
height = 'h-24',
|
||||
height = 'h-16',
|
||||
imageClass = ''
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative overflow-hidden rounded-lg border border-border bg-muted {className}">
|
||||
<div class="group relative overflow-hidden rounded-lg border border-border bg-muted {className}">
|
||||
{#if onClick}
|
||||
<button
|
||||
type="button"
|
||||
@@ -55,17 +54,9 @@
|
||||
|
||||
{#if !readonly}
|
||||
<div
|
||||
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity hover:opacity-100"
|
||||
class="absolute top-1 right-1 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
>
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 text-white hover:bg-white/30"
|
||||
onclick={() => onRemove?.(id)}
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
<RemoveButton {id} {onRemove} class="text-white" />
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -153,7 +153,7 @@
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Content class="grid max-h-[90vh] max-w-5xl overflow-hidden !p-10 sm:w-auto sm:max-w-6xl">
|
||||
<Dialog.Header class="flex-shrink-0">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center justify-between gap-6">
|
||||
<div class="flex items-center gap-3">
|
||||
{#if IconComponent}
|
||||
<IconComponent class="h-5 w-5 text-muted-foreground" />
|
||||
|
||||
+142
-59
@@ -1,11 +1,16 @@
|
||||
<script lang="ts">
|
||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
||||
import ChatAttachmentsViewAllDialog from './ChatAttachmentsViewAllDialog.svelte';
|
||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
style?: string;
|
||||
// For ChatMessage - stored attachments
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
@@ -16,10 +21,13 @@
|
||||
imageClass?: string;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
// Limit display to single row with "+ X more" button
|
||||
limitToSingleRow?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
style = '',
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
@@ -27,36 +35,23 @@
|
||||
// Default to small size for form previews
|
||||
imageClass = '',
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto'
|
||||
imageWidth = 'w-auto',
|
||||
limitToSingleRow = false
|
||||
}: Props = $props();
|
||||
|
||||
let displayItems = $derived(getDisplayItems());
|
||||
|
||||
// Preview dialog state
|
||||
let canScrollLeft = $state(false);
|
||||
let canScrollRight = $state(false);
|
||||
let isScrollable = $state(false);
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewItem = $state<{
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
preview?: string;
|
||||
name?: string;
|
||||
type?: string;
|
||||
size?: number;
|
||||
textContent?: string;
|
||||
} | null>(null);
|
||||
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
|
||||
let scrollContainer: HTMLDivElement | undefined = $state();
|
||||
let showViewAll = $derived(limitToSingleRow && displayItems.length > 0 && isScrollable);
|
||||
let viewAllDialogOpen = $state(false);
|
||||
|
||||
function getDisplayItems() {
|
||||
const items: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
size?: number;
|
||||
preview?: string;
|
||||
type: string;
|
||||
isImage: boolean;
|
||||
uploadedFile?: ChatUploadedFile;
|
||||
attachment?: DatabaseMessageExtra;
|
||||
attachmentIndex?: number;
|
||||
textContent?: string;
|
||||
}> = [];
|
||||
function getDisplayItems(): ChatAttachmentDisplayItem[] {
|
||||
const items: ChatAttachmentDisplayItem[] = [];
|
||||
|
||||
// Add uploaded files (ChatForm)
|
||||
for (const file of uploadedFiles) {
|
||||
@@ -127,14 +122,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
return items;
|
||||
return items.reverse();
|
||||
}
|
||||
|
||||
function openPreview(item: (typeof displayItems)[0], event?: Event) {
|
||||
if (event) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
}
|
||||
function openPreview(item: ChatAttachmentDisplayItem, event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
previewItem = {
|
||||
uploadedFile: item.uploadedFile,
|
||||
@@ -147,38 +140,118 @@
|
||||
};
|
||||
previewDialogOpen = true;
|
||||
}
|
||||
|
||||
function scrollLeft(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * -0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function scrollRight(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * 0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function updateScrollButtons() {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
|
||||
|
||||
canScrollLeft = scrollLeft > 0;
|
||||
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1;
|
||||
isScrollable = scrollWidth > clientWidth;
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (scrollContainer && displayItems.length) {
|
||||
scrollContainer.scrollLeft = 0;
|
||||
|
||||
setTimeout(() => {
|
||||
updateScrollButtons();
|
||||
}, 0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if displayItems.length > 0}
|
||||
<div class="flex flex-wrap items-start {readonly ? 'justify-end' : ''} gap-3 {className}">
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isImage && item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentFilePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
<div class={className} {style}>
|
||||
<div class="relative">
|
||||
<button
|
||||
class="absolute top-1/2 left-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollLeft
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollLeft}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<ChevronLeft class="h-4 w-4" />
|
||||
</button>
|
||||
|
||||
<div
|
||||
class="scrollbar-hide flex items-start gap-3 overflow-x-auto"
|
||||
bind:this={scrollContainer}
|
||||
onscroll={updateScrollButtons}
|
||||
>
|
||||
{#each displayItems as item (item.id)}
|
||||
{#if item.isImage && item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatAttachmentFilePreview
|
||||
class="flex-shrink-0 cursor-pointer {limitToSingleRow ? 'first:ml-4 last:mr-4' : ''}"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="absolute top-1/2 right-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollRight
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollRight}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<ChevronRight class="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{#if showViewAll}
|
||||
<div class="mt-2 -mr-2 flex justify-end px-4">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 text-xs text-muted-foreground hover:text-foreground"
|
||||
onclick={() => (viewAllDialogOpen = true)}
|
||||
>
|
||||
View all
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -194,3 +267,13 @@
|
||||
textContent={previewItem.textContent}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<ChatAttachmentsViewAllDialog
|
||||
bind:open={viewAllDialogOpen}
|
||||
{uploadedFiles}
|
||||
{attachments}
|
||||
{readonly}
|
||||
{onFileRemove}
|
||||
imageHeight="h-64"
|
||||
{imageClass}
|
||||
/>
|
||||
|
||||
+203
@@ -0,0 +1,203 @@
|
||||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ChatAttachmentImagePreview, ChatAttachmentFilePreview } from '$lib/components/app';
|
||||
import { FileTypeCategory } from '$lib/enums/files';
|
||||
import { getFileTypeCategory } from '$lib/utils/file-type';
|
||||
import ChatAttachmentPreviewDialog from './ChatAttachmentPreviewDialog.svelte';
|
||||
import type { ChatAttachmentDisplayItem, ChatAttachmentPreviewItem } from '$lib/types/chat';
|
||||
|
||||
interface Props {
|
||||
open?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
readonly?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
imageHeight?: string;
|
||||
imageWidth?: string;
|
||||
imageClass?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(false),
|
||||
uploadedFiles = [],
|
||||
attachments = [],
|
||||
readonly = false,
|
||||
onFileRemove,
|
||||
imageHeight = 'h-24',
|
||||
imageWidth = 'w-auto',
|
||||
imageClass = ''
|
||||
}: Props = $props();
|
||||
|
||||
let previewDialogOpen = $state(false);
|
||||
let previewItem = $state<ChatAttachmentPreviewItem | null>(null);
|
||||
|
||||
let displayItems = $derived(getDisplayItems());
|
||||
let imageItems = $derived(displayItems.filter((item) => item.isImage));
|
||||
let fileItems = $derived(displayItems.filter((item) => !item.isImage));
|
||||
|
||||
function getDisplayItems(): ChatAttachmentDisplayItem[] {
|
||||
const items: ChatAttachmentDisplayItem[] = [];
|
||||
|
||||
for (const file of uploadedFiles) {
|
||||
items.push({
|
||||
id: file.id,
|
||||
name: file.name,
|
||||
size: file.size,
|
||||
preview: file.preview,
|
||||
type: file.type,
|
||||
isImage: getFileTypeCategory(file.type) === FileTypeCategory.IMAGE,
|
||||
uploadedFile: file,
|
||||
textContent: file.textContent
|
||||
});
|
||||
}
|
||||
|
||||
for (const [index, attachment] of attachments.entries()) {
|
||||
if (attachment.type === 'imageFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
preview: attachment.base64Url,
|
||||
type: 'image',
|
||||
isImage: true,
|
||||
attachment,
|
||||
attachmentIndex: index
|
||||
});
|
||||
} else if (attachment.type === 'textFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'text',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
} else if (attachment.type === 'context') {
|
||||
// Legacy format from old webui - treat as text file
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'text',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
} else if (attachment.type === 'audioFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: attachment.mimeType || 'audio',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index
|
||||
});
|
||||
} else if (attachment.type === 'pdfFile') {
|
||||
items.push({
|
||||
id: `attachment-${index}`,
|
||||
name: attachment.name,
|
||||
type: 'application/pdf',
|
||||
isImage: false,
|
||||
attachment,
|
||||
attachmentIndex: index,
|
||||
textContent: attachment.content
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return items.reverse();
|
||||
}
|
||||
|
||||
function openPreview(item: (typeof displayItems)[0], event?: Event) {
|
||||
if (event) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
}
|
||||
|
||||
previewItem = {
|
||||
uploadedFile: item.uploadedFile,
|
||||
attachment: item.attachment,
|
||||
preview: item.preview,
|
||||
name: item.name,
|
||||
type: item.type,
|
||||
size: item.size,
|
||||
textContent: item.textContent
|
||||
};
|
||||
previewDialogOpen = true;
|
||||
}
|
||||
</script>
|
||||
|
||||
<Dialog.Root bind:open>
|
||||
<Dialog.Portal>
|
||||
<Dialog.Overlay />
|
||||
|
||||
<Dialog.Content class="flex !max-h-[90vh] !max-w-6xl flex-col">
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>All Attachments ({displayItems.length})</Dialog.Title>
|
||||
<Dialog.Description class="text-sm text-muted-foreground">
|
||||
View and manage all attached files
|
||||
</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
<div class="min-h-0 flex-1 space-y-6 overflow-y-auto px-1">
|
||||
{#if fileItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Files ({fileItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each fileItems as item (item.id)}
|
||||
<ChatAttachmentFilePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
type={item.type}
|
||||
size={item.size}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
textContent={item.textContent}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if imageItems.length > 0}
|
||||
<div>
|
||||
<h3 class="mb-3 text-sm font-medium text-foreground">Images ({imageItems.length})</h3>
|
||||
<div class="flex flex-wrap items-start gap-3">
|
||||
{#each imageItems as item (item.id)}
|
||||
{#if item.preview}
|
||||
<ChatAttachmentImagePreview
|
||||
class="cursor-pointer"
|
||||
id={item.id}
|
||||
name={item.name}
|
||||
preview={item.preview}
|
||||
{readonly}
|
||||
onRemove={onFileRemove}
|
||||
height={imageHeight}
|
||||
width={imageWidth}
|
||||
{imageClass}
|
||||
onClick={(event) => openPreview(item, event)}
|
||||
/>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
</Dialog.Root>
|
||||
|
||||
{#if previewItem}
|
||||
<ChatAttachmentPreviewDialog
|
||||
bind:open={previewDialogOpen}
|
||||
uploadedFile={previewItem.uploadedFile}
|
||||
attachment={previewItem.attachment}
|
||||
preview={previewItem.preview}
|
||||
name={previewItem.name}
|
||||
type={previewItem.type}
|
||||
size={previewItem.size}
|
||||
textContent={previewItem.textContent}
|
||||
/>
|
||||
{/if}
|
||||
@@ -232,7 +232,13 @@
|
||||
onsubmit={handleSubmit}
|
||||
class="{INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {className}"
|
||||
>
|
||||
<ChatAttachmentsList bind:uploadedFiles {onFileRemove} class="mb-3 px-5 pt-5" />
|
||||
<ChatAttachmentsList
|
||||
bind:uploadedFiles
|
||||
{onFileRemove}
|
||||
limitToSingleRow
|
||||
class="py-5"
|
||||
style="scroll-padding: 1rem;"
|
||||
/>
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl px-5 py-3 shadow-sm transition-all focus-within:shadow-md"
|
||||
|
||||
+1
-7
@@ -72,12 +72,6 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (isOpen) {
|
||||
updateMenuPosition();
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSelect(value: string | undefined) {
|
||||
if (!value) return;
|
||||
|
||||
@@ -259,7 +253,7 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<svelte:window onresize={handleResize} onscroll={handleScroll} />
|
||||
<svelte:window onresize={handleResize} />
|
||||
|
||||
<svelte:document onpointerdown={handlePointerDown} onkeydown={handleKeydown} />
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { getDeletionInfo } from '$lib/stores/chat.svelte';
|
||||
import { copyToClipboard } from '$lib/utils/copy';
|
||||
import { isIMEComposing } from '$lib/utils/is-ime-composing';
|
||||
import type { ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
|
||||
import ChatMessageUser from './ChatMessageUser.svelte';
|
||||
|
||||
@@ -54,6 +55,29 @@
|
||||
return null;
|
||||
});
|
||||
|
||||
let toolCallContent = $derived.by((): ApiChatCompletionToolCall[] | string | null => {
|
||||
if (message.role === 'assistant') {
|
||||
const trimmedToolCalls = message.toolCalls?.trim();
|
||||
|
||||
if (!trimmedToolCalls) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(trimmedToolCalls);
|
||||
|
||||
if (Array.isArray(parsed)) {
|
||||
return parsed as ApiChatCompletionToolCall[];
|
||||
}
|
||||
} catch {
|
||||
// Harmony-only path: fall back to the raw string so issues surface visibly.
|
||||
}
|
||||
|
||||
return trimmedToolCalls;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
function handleCancelEdit() {
|
||||
isEditing = false;
|
||||
editedContent = message.content;
|
||||
@@ -171,5 +195,6 @@
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
{thinkingContent}
|
||||
{toolCallContent}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
+117
-2
@@ -11,7 +11,8 @@
|
||||
Gauge,
|
||||
Clock,
|
||||
WholeWord,
|
||||
ChartNoAxesColumn
|
||||
ChartNoAxesColumn,
|
||||
Wrench
|
||||
} from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
@@ -21,6 +22,7 @@
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { modelName as serverModelName } from '$lib/stores/server.svelte';
|
||||
import { copyToClipboard } from '$lib/utils/copy';
|
||||
import type { ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -51,6 +53,7 @@
|
||||
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||
textareaElement?: HTMLTextAreaElement;
|
||||
thinkingContent: string | null;
|
||||
toolCallContent: ApiChatCompletionToolCall[] | string | null;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -76,9 +79,15 @@
|
||||
shouldBranchAfterEdit = false,
|
||||
siblingInfo = null,
|
||||
textareaElement = $bindable(),
|
||||
thinkingContent
|
||||
thinkingContent,
|
||||
toolCallContent = null
|
||||
}: Props = $props();
|
||||
|
||||
const toolCalls = $derived(
|
||||
Array.isArray(toolCallContent) ? (toolCallContent as ApiChatCompletionToolCall[]) : null
|
||||
);
|
||||
const fallbackToolCalls = $derived(typeof toolCallContent === 'string' ? toolCallContent : null);
|
||||
|
||||
const processingState = useProcessingState();
|
||||
let currentConfig = $derived(config());
|
||||
let serverModel = $derived(serverModelName());
|
||||
@@ -97,6 +106,58 @@
|
||||
|
||||
void copyToClipboard(model ?? '');
|
||||
}
|
||||
|
||||
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
|
||||
const callNumber = index + 1;
|
||||
const functionName = toolCall.function?.name?.trim();
|
||||
const label = functionName || `Call #${callNumber}`;
|
||||
|
||||
const payload: Record<string, unknown> = {};
|
||||
|
||||
const id = toolCall.id?.trim();
|
||||
if (id) {
|
||||
payload.id = id;
|
||||
}
|
||||
|
||||
const type = toolCall.type?.trim();
|
||||
if (type) {
|
||||
payload.type = type;
|
||||
}
|
||||
|
||||
if (toolCall.function) {
|
||||
const fnPayload: Record<string, unknown> = {};
|
||||
|
||||
const name = toolCall.function.name?.trim();
|
||||
if (name) {
|
||||
fnPayload.name = name;
|
||||
}
|
||||
|
||||
const rawArguments = toolCall.function.arguments?.trim();
|
||||
if (rawArguments) {
|
||||
try {
|
||||
fnPayload.arguments = JSON.parse(rawArguments);
|
||||
} catch {
|
||||
fnPayload.arguments = rawArguments;
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(fnPayload).length > 0) {
|
||||
payload.function = fnPayload;
|
||||
}
|
||||
}
|
||||
|
||||
const formattedPayload = JSON.stringify(payload, null, 2);
|
||||
|
||||
return {
|
||||
label,
|
||||
tooltip: formattedPayload,
|
||||
copyValue: formattedPayload
|
||||
};
|
||||
}
|
||||
|
||||
function handleCopyToolCall(payload: string) {
|
||||
void copyToClipboard(payload, 'Tool call copied to clipboard');
|
||||
}
|
||||
</script>
|
||||
|
||||
<div
|
||||
@@ -189,6 +250,47 @@
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
{#if config().showToolCalls}
|
||||
{#if (toolCalls && toolCalls.length > 0) || fallbackToolCalls}
|
||||
<span class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
|
||||
<span class="inline-flex items-center gap-1">
|
||||
<Wrench class="h-3.5 w-3.5" />
|
||||
|
||||
<span>Tool calls:</span>
|
||||
</span>
|
||||
|
||||
{#if toolCalls && toolCalls.length > 0}
|
||||
{#each toolCalls as toolCall, index (toolCall.id ?? `${index}`)}
|
||||
{@const badge = formatToolCallBadge(toolCall, index)}
|
||||
<button
|
||||
type="button"
|
||||
class="tool-call-badge inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
title={badge.tooltip}
|
||||
aria-label={`Copy tool call ${badge.label}`}
|
||||
onclick={() => handleCopyToolCall(badge.copyValue)}
|
||||
>
|
||||
{badge.label}
|
||||
|
||||
<Copy class="ml-1 h-3 w-3" />
|
||||
</button>
|
||||
{/each}
|
||||
{:else if fallbackToolCalls}
|
||||
<button
|
||||
type="button"
|
||||
class="tool-call-badge tool-call-badge--fallback inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
title={fallbackToolCalls}
|
||||
aria-label="Copy tool call payload"
|
||||
onclick={() => handleCopyToolCall(fallbackToolCalls)}
|
||||
>
|
||||
{fallbackToolCalls}
|
||||
|
||||
<Copy class="ml-1 h-3 w-3" />
|
||||
</button>
|
||||
{/if}
|
||||
</span>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
|
||||
{@const tokensPerSecond = (message.timings.predicted_n / message.timings.predicted_ms) * 1000}
|
||||
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
|
||||
@@ -287,4 +389,17 @@
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tool-call-badge {
|
||||
max-width: 12rem;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.tool-call-badge--fallback {
|
||||
max-width: 20rem;
|
||||
white-space: normal;
|
||||
word-break: break-word;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -76,10 +76,10 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="chat-processing-info-container" class:visible={showSlotsInfo}>
|
||||
<div class="chat-processing-info-container pointer-events-none" class:visible={showSlotsInfo}>
|
||||
<div class="chat-processing-info-content">
|
||||
{#each processingDetails as detail (detail)}
|
||||
<span class="chat-processing-info-detail">{detail}</span>
|
||||
<span class="chat-processing-info-detail pointer-events-auto">{detail}</span>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
@@ -92,7 +92,6 @@
|
||||
padding: 1.5rem 1rem;
|
||||
opacity: 0;
|
||||
transform: translateY(50%);
|
||||
pointer-events: none;
|
||||
transition:
|
||||
opacity 300ms ease-out,
|
||||
transform 300ms ease-out;
|
||||
@@ -100,7 +99,6 @@
|
||||
|
||||
.chat-processing-info-container.visible {
|
||||
opacity: 1;
|
||||
pointer-events: auto;
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
|
||||
@@ -333,7 +333,7 @@
|
||||
ondrop={handleDrop}
|
||||
role="main"
|
||||
>
|
||||
<div class="w-full max-w-2xl px-4">
|
||||
<div class="w-full max-w-[48rem] px-4">
|
||||
<div class="mb-8 text-center" in:fade={{ duration: 300 }}>
|
||||
<h1 class="mb-2 text-3xl font-semibold tracking-tight">llama.cpp</h1>
|
||||
|
||||
@@ -368,7 +368,7 @@
|
||||
<AlertDialog.Portal>
|
||||
<AlertDialog.Overlay />
|
||||
|
||||
<AlertDialog.Content class="max-w-md">
|
||||
<AlertDialog.Content class="flex max-w-md flex-col">
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title>File Upload Error</AlertDialog.Title>
|
||||
|
||||
@@ -377,7 +377,7 @@
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="!max-h-[50vh] min-h-0 flex-1 space-y-4 overflow-y-auto">
|
||||
{#if fileErrorData.generallyUnsupported.length > 0}
|
||||
<div class="space-y-2">
|
||||
<h4 class="text-sm font-medium text-destructive">Unsupported File Types</h4>
|
||||
@@ -398,8 +398,6 @@
|
||||
|
||||
{#if fileErrorData.modalityUnsupported.length > 0}
|
||||
<div class="space-y-2">
|
||||
<h4 class="text-sm font-medium text-destructive">Model Compatibility Issues</h4>
|
||||
|
||||
<div class="space-y-1">
|
||||
{#each fileErrorData.modalityUnsupported as file (file.name)}
|
||||
<div class="rounded-md bg-destructive/10 px-3 py-2">
|
||||
@@ -415,14 +413,14 @@
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div class="rounded-md bg-muted/50 p-3">
|
||||
<h4 class="mb-2 text-sm font-medium">This model supports:</h4>
|
||||
<div class="rounded-md bg-muted/50 p-3">
|
||||
<h4 class="mb-2 text-sm font-medium">This model supports:</h4>
|
||||
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{fileErrorData.supportedTypes.join(', ')}
|
||||
</p>
|
||||
</div>
|
||||
<p class="text-sm text-muted-foreground">
|
||||
{fileErrorData.supportedTypes.join(', ')}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<AlertDialog.Footer>
|
||||
|
||||
@@ -226,6 +226,11 @@
|
||||
label: 'Enable model selector',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showToolCalls',
|
||||
label: 'Show tool call labels',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'disableReasoningFormat',
|
||||
label: 'Show raw LLM output',
|
||||
|
||||
@@ -2,6 +2,7 @@ export { default as ChatAttachmentsList } from './chat/ChatAttachments/ChatAttac
|
||||
export { default as ChatAttachmentFilePreview } from './chat/ChatAttachments/ChatAttachmentFilePreview.svelte';
|
||||
export { default as ChatAttachmentImagePreview } from './chat/ChatAttachments/ChatAttachmentImagePreview.svelte';
|
||||
export { default as ChatAttachmentPreviewDialog } from './chat/ChatAttachments/ChatAttachmentPreviewDialog.svelte';
|
||||
export { default as ChatAttachmentsViewAllDialog } from './chat/ChatAttachments/ChatAttachmentsViewAllDialog.svelte';
|
||||
|
||||
export { default as ChatForm } from './chat/ChatForm/ChatForm.svelte';
|
||||
export { default as ChatFormTextarea } from './chat/ChatForm/ChatFormTextarea.svelte';
|
||||
@@ -42,6 +43,8 @@ export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.sve
|
||||
|
||||
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
|
||||
|
||||
export { default as RemoveButton } from './misc/RemoveButton.svelte';
|
||||
|
||||
export { default as ServerStatus } from './server/ServerStatus.svelte';
|
||||
export { default as ServerErrorSplash } from './server/ServerErrorSplash.svelte';
|
||||
export { default as ServerLoadingSplash } from './server/ServerLoadingSplash.svelte';
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
onRemove?: (id: string) => void;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { id, onRemove, class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}"
|
||||
onclick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user