mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-29 17:17:40 +02:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 38e2c1b412 | |||
| cb44fc84e8 | |||
| cb623de3fc | |||
| 7aaeedc098 | |||
| 3347e6d904 | |||
| 1a139644a8 | |||
| 2376b7758c | |||
| dbed61294a | |||
| 80deff3648 | |||
| 8b1c339bd2 | |||
| 416e7c7f47 | |||
| 5b2093becc | |||
| 52e5d421f1 | |||
| 4db5641210 | |||
| 72bd7321a7 | |||
| 22e1ce2f81 | |||
| 1411d9275a | |||
| 662192e1dc | |||
| 24dc769f1b | |||
| 4dca015b7e | |||
| 9a8860cf5d | |||
| 9d3ef4809f | |||
| c7b7db0445 | |||
| 1568d13c2c | |||
| 439342ea0b | |||
| 234ae7d7bd | |||
| 38eaf32af1 | |||
| 9b17d74ab7 | |||
| e1fcf8b09b | |||
| 6cd0cf72ce |
@@ -1,52 +0,0 @@
|
||||
name: CI (AMD)
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/build-amd.yml',
|
||||
'**/CMakeLists.txt',
|
||||
'**/.cmake',
|
||||
'**/*.h',
|
||||
'**/*.hpp',
|
||||
'**/*.c',
|
||||
'**/*.cpp',
|
||||
'**/*.cu',
|
||||
'**/*.cuh',
|
||||
'**/*.comp'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
@@ -1599,6 +1599,34 @@ jobs:
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
|
||||
+1
-5
@@ -355,11 +355,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
void common_init() {
|
||||
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}, NULL);
|
||||
llama_log_set(common_log_default_callback, NULL);
|
||||
|
||||
#ifdef NDEBUG
|
||||
const char * build_type = "";
|
||||
|
||||
@@ -442,3 +442,9 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
|
||||
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
||||
log->set_timestamps(timestamps);
|
||||
}
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
|
||||
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
||||
common_log_add(common_log_main(), level, "%s", text);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,8 @@ extern int common_log_verbosity_thold;
|
||||
|
||||
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||
|
||||
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
||||
// the common_log uses an internal worker thread to print/write log messages
|
||||
// when the worker thread is paused, incoming log messages are discarded
|
||||
struct common_log;
|
||||
|
||||
+81
-43
@@ -189,10 +189,10 @@ class ModelBase:
|
||||
return tensors
|
||||
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
|
||||
is_safetensors: bool = len(part_names) > 0
|
||||
if not is_safetensors:
|
||||
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
|
||||
|
||||
tensor_names_from_index: set[str] = set()
|
||||
|
||||
@@ -209,6 +209,7 @@ class ModelBase:
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
part_names |= set(weight_map.values())
|
||||
else:
|
||||
weight_map = {}
|
||||
else:
|
||||
@@ -825,6 +826,15 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_expert_group_used_count(n_group_used)
|
||||
logger.info(f"gguf: expert groups used count = {n_group_used}")
|
||||
|
||||
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func"], optional=True)) is not None:
|
||||
if score_func == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif score_func == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
|
||||
logger.info(f"gguf: expert score gating function = {score_func}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
self.gguf_writer.add_value_length(head_dim)
|
||||
@@ -1124,6 +1134,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
|
||||
# ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
|
||||
res = "afmoe"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
|
||||
res = "bailingmoe2"
|
||||
@@ -2533,6 +2546,72 @@ class ArceeModel(LlamaModel):
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@ModelBase.register("AfmoeForCausalLM")
|
||||
class AfmoeModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.AFMOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# MoE parameters
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
|
||||
self.gguf_writer.add_expert_shared_count(n_shared_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
|
||||
self.gguf_writer.add_leading_dense_block_count(n_dense_layers)
|
||||
|
||||
# Route normalization and scaling
|
||||
if (route_norm := self.hparams.get("route_norm")) is not None:
|
||||
self.gguf_writer.add_expert_weights_norm(route_norm)
|
||||
if (route_scale := self.hparams.get("route_scale")) is not None:
|
||||
self.gguf_writer.add_expert_weights_scale(route_scale)
|
||||
|
||||
# Sliding window attention
|
||||
if (sliding_window := self.hparams.get("sliding_window")) is not None:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Handle expert weights - they're already merged in the HF format
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["gate_proj", "up_proj", "down_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename_to_retrieve])
|
||||
del self._experts[bid][ename_to_retrieve]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"LlavaForConditionalGeneration", # pixtral
|
||||
"Mistral3ForConditionalGeneration", # mistral small 3.1
|
||||
@@ -7104,13 +7183,6 @@ class DeepseekV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["scoring_func"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["scoring_func"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
@@ -7216,12 +7288,6 @@ class MiniMaxM2Model(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hparams["scoring_func"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif self.hparams["scoring_func"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
|
||||
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
|
||||
@@ -7314,11 +7380,6 @@ class Dots1Model(Qwen2MoeModel):
|
||||
self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
|
||||
|
||||
if self.hparams["scoring_func"] == "noaux_tc":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.endswith("e_score_correction_bias"):
|
||||
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||
@@ -7779,12 +7840,6 @@ class Glm4MoeModel(TextModel):
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
|
||||
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
|
||||
|
||||
# Patch broken chat template
|
||||
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
|
||||
special_vocab.chat_template = special_vocab.chat_template.replace(
|
||||
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
|
||||
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
|
||||
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
@@ -8639,13 +8694,6 @@ class BailingMoeV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["score_function"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
|
||||
|
||||
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
|
||||
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
|
||||
|
||||
@@ -9341,16 +9389,6 @@ class HunYuanModel(TextModel):
|
||||
class SmolLM3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.SMOLLM3
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
# remove unsupported array slicing in chat template
|
||||
# ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
if tokenizer.chat_template is not None:
|
||||
chat_template = tokenizer.chat_template.replace("[:]", "")
|
||||
self.gguf_writer.add_chat_template(chat_template)
|
||||
|
||||
|
||||
@ModelBase.register("GptOssForCausalLM")
|
||||
class GptOssModel(TextModel):
|
||||
|
||||
@@ -139,6 +139,7 @@ models = [
|
||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
|
||||
|
||||
+35
-36
@@ -14,24 +14,24 @@ Legend:
|
||||
|
||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
||||
|-----------|------|------|------|------|------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -40,8 +40,8 @@ Legend:
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
@@ -50,40 +50,40 @@ Legend:
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
@@ -93,29 +93,28 @@ Legend:
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
|
||||
+2348
-149
File diff suppressed because it is too large
Load Diff
+14536
-4360
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -48,15 +48,14 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
|
||||
default:
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
|
||||
// added.
|
||||
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
|
||||
@@ -87,10 +86,20 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
std::reverse(acl_ne, acl_ne + final_dims);
|
||||
std::reverse(acl_stride, acl_stride + final_dims);
|
||||
|
||||
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1, tensor->data);
|
||||
aclTensor * raw = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, elem_offset,
|
||||
format, &acl_storage_len, 1, tensor->data);
|
||||
|
||||
return acl_tensor;
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size) {
|
||||
aclIntArray * raw = aclCreateIntArray(value, size);
|
||||
return acl_int_array_ptr(raw);
|
||||
}
|
||||
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType) {
|
||||
aclScalar * raw = aclCreateScalar(value, dataType);
|
||||
return acl_scalar_ptr(raw);
|
||||
}
|
||||
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
|
||||
|
||||
@@ -23,11 +23,12 @@
|
||||
#ifndef CANN_ACL_TENSOR_H
|
||||
#define CANN_ACL_TENSOR_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include "common.h"
|
||||
|
||||
#include <aclnn/aclnn_base.h>
|
||||
#include "common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
/**
|
||||
* @brief Maps a ggml_type to its corresponding aclDataType.
|
||||
@@ -43,6 +44,20 @@
|
||||
*/
|
||||
aclDataType ggml_cann_type_mapping(ggml_type type);
|
||||
|
||||
// Deleter for acl objects.
|
||||
template <typename T, aclError (*DestroyFunc)(const T *)> struct acl_deleter {
|
||||
void operator()(T * ptr) const noexcept {
|
||||
if (ptr) {
|
||||
ACL_CHECK(DestroyFunc(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using acl_tensor_ptr = std::unique_ptr<aclTensor, acl_deleter<aclTensor, aclDestroyTensor>>;
|
||||
using acl_int_array_ptr = std::unique_ptr<aclIntArray, acl_deleter<aclIntArray, aclDestroyIntArray>>;
|
||||
using acl_scalar_ptr = std::unique_ptr<aclScalar, acl_deleter<aclScalar, aclDestroyScalar>>;
|
||||
using acl_tensor_list_ptr = std::unique_ptr<aclTensorList, acl_deleter<aclTensorList, aclDestroyTensorList>>;
|
||||
|
||||
/**
|
||||
* @brief Creates an ACL tensor from a ggml_tensor with optional shape.
|
||||
*
|
||||
@@ -62,12 +77,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
|
||||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
acl_tensor_ptr ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
|
||||
/**
|
||||
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
|
||||
@@ -90,14 +105,14 @@ aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
template <typename TYPE>
|
||||
aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
acl_tensor_ptr ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
int64_t tmp_ne[GGML_MAX_DIMS * 2];
|
||||
int64_t tmp_stride[GGML_MAX_DIMS * 2];
|
||||
|
||||
@@ -114,10 +129,75 @@ aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
std::reverse(tmp_ne, tmp_ne + dims);
|
||||
std::reverse(tmp_stride, tmp_stride + dims);
|
||||
|
||||
aclTensor * acl_tensor =
|
||||
aclTensor * raw =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
|
||||
|
||||
return acl_tensor;
|
||||
return acl_tensor_ptr(raw);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create an ACL int array resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclIntArray from the provided int64_t values
|
||||
* and returns it as an acl_int_array_ptr (a std::unique_ptr with a custom
|
||||
* deleter). The returned pointer owns the ACL resource and will automatically
|
||||
* destroy it via aclDestroyIntArray().
|
||||
*
|
||||
* @param value Pointer to the int64_t elements.
|
||||
* @param size Number of elements in value.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL int array.
|
||||
*/
|
||||
acl_int_array_ptr ggml_cann_create_int_array(const int64_t * value, uint64_t size);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL scalar resource wrapped in a smart pointer.
|
||||
*
|
||||
* This function constructs an aclScalar from the raw value pointer and ACL
|
||||
* data type, then returns it as an acl_scalar_ptr (a std::unique_ptr with
|
||||
* a custom deleter). The returned pointer owns the ACL scalar and will
|
||||
* automatically destroy it via aclDestroyScalar().
|
||||
*
|
||||
* @param value Pointer to the raw scalar memory.
|
||||
* @param dataType ACL data type of the scalar.
|
||||
*
|
||||
* @return A smart pointer managing the created ACL scalar.
|
||||
*/
|
||||
acl_scalar_ptr ggml_cann_create_scalar(void * value, aclDataType dataType);
|
||||
|
||||
/**
|
||||
* @brief Create an ACL tensor list from multiple tensor smart pointers.
|
||||
*
|
||||
* This function accepts a variadic list of acl_tensor_ptr (a unique_ptr with
|
||||
* custom deleter) and produces an aclTensorList using aclCreateTensorList().
|
||||
*
|
||||
* The lifecycle management of the tensor objects changes as follows:
|
||||
* - aclCreateTensorList() takes ownership of the tensors
|
||||
* - Each input smart pointer releases ownership using release()
|
||||
* - As a result, the tensors will NOT be destroyed by unique_ptr
|
||||
* - Instead, they will be destroyed when aclDestroyTensorList() is called
|
||||
*
|
||||
* This ensures correct ownership transfer and prevents double-free situations.
|
||||
*
|
||||
* @param acl_tensor_ptr Variadic template parameter; each argument must be
|
||||
* a unique_ptr-like type supporting get() and release().
|
||||
*
|
||||
* @param tensors Variadic list of acl_tensor_ptr objects. Ownership of
|
||||
* each tensor is transferred away from these smart pointers.
|
||||
*
|
||||
* @return A smart pointer (acl_tensor_list_ptr) owning the created ACL tensor list.
|
||||
*
|
||||
* @note This implementation is C++11 compatible. The ownership-release process is
|
||||
* executed using a pack expansion inside an initializer list.
|
||||
*/
|
||||
template <typename... acl_tensor_ptr> acl_tensor_list_ptr ggml_cann_create_tensor_list(acl_tensor_ptr &&... tensors) {
|
||||
aclTensor * raw_tensors[] = { tensors.get()... };
|
||||
aclTensorList * raw = aclCreateTensorList(raw_tensors, sizeof...(tensors));
|
||||
// aclTensor will release by aclTensorList, so release ownership without
|
||||
// destroying the tensor
|
||||
int dummy[] = { (tensors.release(), 0)... };
|
||||
GGML_UNUSED(dummy);
|
||||
return acl_tensor_list_ptr(raw);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
+587
-691
File diff suppressed because it is too large
Load Diff
+42
-197
@@ -23,33 +23,35 @@
|
||||
#ifndef CANN_ACLNN_OPS
|
||||
#define CANN_ACLNN_OPS
|
||||
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <aclnnop/aclnn_abs.h>
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_arange.h>
|
||||
#include <aclnnop/aclnn_argsort.h>
|
||||
#include <aclnnop/aclnn_cat.h>
|
||||
#include <aclnnop/aclnn_clamp.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_exp.h>
|
||||
#include <aclnnop/aclnn_gelu.h>
|
||||
#include <aclnnop/aclnn_gelu_v2.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_hardsigmoid.h>
|
||||
#include <aclnnop/aclnn_hardswish.h>
|
||||
#include <aclnnop/aclnn_leaky_relu.h>
|
||||
#include <aclnnop/aclnn_relu.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_cos.h>
|
||||
#include <aclnnop/aclnn_log.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_norm.h>
|
||||
#include <aclnnop/aclnn_logsoftmax.h>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
#include <aclnnop/aclnn_neg.h>
|
||||
#include <aclnnop/aclnn_norm.h>
|
||||
#include <aclnnop/aclnn_relu.h>
|
||||
#include <aclnnop/aclnn_sigmoid.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_silu.h>
|
||||
#include <aclnnop/aclnn_sin.h>
|
||||
#include <aclnnop/aclnn_sqrt.h>
|
||||
#include <aclnnop/aclnn_tanh.h>
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_set>
|
||||
|
||||
/**
|
||||
* @brief Repeats a ggml tensor along each dimension to match the dimensions
|
||||
@@ -688,12 +690,12 @@ void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
|
||||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
aclTensor ** acl_src0,
|
||||
aclTensor ** acl_src1,
|
||||
aclTensor ** acl_dst);
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
acl_tensor_ptr & acl_src0,
|
||||
acl_tensor_ptr & acl_src1,
|
||||
acl_tensor_ptr & acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
|
||||
@@ -873,83 +875,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
|
||||
(vec.emplace_back(make_acl_resource(args)), ...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Task class that wraps the execution of an aclnn function call.
|
||||
*/
|
||||
class aclnn_task : public cann_task {
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func,
|
||||
void * workspace_addr,
|
||||
uint64_t workspace_size,
|
||||
aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
uint64_t workspace_size_;
|
||||
aclOpExecutor * executor_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class that releases ACL resources after usage.
|
||||
*/
|
||||
class release_resource_task : public cann_task {
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
|
||||
|
||||
virtual void run_task() override { resource_.clear(); }
|
||||
private:
|
||||
std::vector<any_acl_resource> resource_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory copy operations.
|
||||
*/
|
||||
class async_memcpy_task : public cann_task {
|
||||
public:
|
||||
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
|
||||
dst_(dst),
|
||||
src_(src),
|
||||
size_(size),
|
||||
kind_(kind),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
|
||||
private:
|
||||
void * dst_;
|
||||
const void * src_;
|
||||
size_t size_;
|
||||
aclrtMemcpyKind kind_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory set operations.
|
||||
*/
|
||||
class async_memset_task : public cann_task {
|
||||
public:
|
||||
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
|
||||
buffer_(buffer),
|
||||
size_(size),
|
||||
value_(value),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
|
||||
private:
|
||||
void * buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Launches an asynchronous task using the memory allocator.
|
||||
*
|
||||
@@ -968,95 +893,20 @@ class async_memset_task : public cann_task {
|
||||
* same stream are executed in queue order.
|
||||
*/
|
||||
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
if (CTX.async_mode) { \
|
||||
auto task = \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
|
||||
CTX.task_queue.submit_task(std::move(task)); \
|
||||
} else { \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} \
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Registers and releases multiple ACL resources, optionally deferring the release
|
||||
* using a task.
|
||||
*
|
||||
* @tparam Args Types of the ACL resources.
|
||||
* @param ctx Backend context which manages task submission and async mode.
|
||||
* @param args Pointers to ACL resources to be released.
|
||||
*/
|
||||
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
std::vector<any_acl_resource> resources;
|
||||
register_acl_resources(resources, std::forward<Args>(args)...);
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<release_resource_task>(std::move(resources));
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param dst Destination memory address.
|
||||
* @param src Source memory address.
|
||||
* @param len Size of memory to copy (in bytes).
|
||||
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
|
||||
*/
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx->async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
|
||||
ctx->task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
|
||||
*
|
||||
* @param ctx Backend context containing stream and async configuration.
|
||||
* @param buffer Memory buffer to be set.
|
||||
* @param size Size of the memory buffer (in bytes).
|
||||
* @param value Value to set in the buffer.
|
||||
*/
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
} else {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
|
||||
*
|
||||
@@ -1129,15 +979,11 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
aclTensor * acl_src0;
|
||||
aclTensor * acl_src1;
|
||||
aclTensor * acl_dst;
|
||||
acl_tensor_ptr acl_src0, acl_src1, acl_dst;
|
||||
|
||||
// Need bcast
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
binary_op(ctx, acl_src0, acl_src1, acl_dst);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
bcast_shape(src0, src1, dst, acl_src0, acl_src1, acl_dst);
|
||||
binary_op(ctx, acl_src0.get(), acl_src1.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1147,7 +993,7 @@ template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & c
|
||||
* and stores the result in the destination tensor.
|
||||
*
|
||||
* @tparam unary_op A callable with the signature:
|
||||
* void(ggml_backend_cann_context&, aclTensor*, aclTensor*)
|
||||
* void(ggml_backend_cann_context&, aclTensor *, aclTensor *)
|
||||
* where the first aclTensor is the source and the second is the destination.
|
||||
* @param ctx The CANN backend context for managing resources and execution.
|
||||
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
|
||||
@@ -1156,11 +1002,10 @@ template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
aclTensor * acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
|
||||
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
unary_op(ctx, acl_src.get(), acl_dst.get());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
+19
-150
@@ -23,26 +23,26 @@
|
||||
#ifndef CANN_COMMON_H
|
||||
#define CANN_COMMON_H
|
||||
|
||||
#include <acl/acl.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <list>
|
||||
|
||||
#include "../ggml-impl.h"
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
#include "../ggml-impl.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <cstdio>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#define MATRIX_ROW_PADDING 512
|
||||
#define GGML_CANN_MAX_STREAMS 8
|
||||
@@ -214,130 +214,6 @@ struct ggml_cann_pool_alloc {
|
||||
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Function pointer type for ACLNN operator calls.
|
||||
*/
|
||||
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
|
||||
|
||||
/**
|
||||
* @brief Base class for all CANN tasks to be submitted to the task queue.
|
||||
*
|
||||
* Users should override the run_task() method with actual task logic.
|
||||
*/
|
||||
class cann_task {
|
||||
public:
|
||||
virtual void run_task() {}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
|
||||
*/
|
||||
class cann_task_queue {
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
|
||||
*
|
||||
* @param capacity Queue capacity. Must be a power of 2.
|
||||
* @param device Target device ID (used for context setting).
|
||||
*/
|
||||
explicit cann_task_queue(size_t capacity, int32_t device) :
|
||||
buffer_(capacity),
|
||||
capacity_(capacity),
|
||||
head_(0),
|
||||
tail_(0),
|
||||
running_(false),
|
||||
device_(device) {
|
||||
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
|
||||
mask_ = capacity_ - 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Attempts to enqueue a task into the queue.
|
||||
*
|
||||
* @param item Unique pointer to the task.
|
||||
* @return true if the task was successfully enqueued, false if the queue was full.
|
||||
*/
|
||||
bool enqueue(std::unique_ptr<cann_task> && item) {
|
||||
size_t next_tail = (tail_ + 1) & mask_;
|
||||
|
||||
if (next_tail == head_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
buffer_[tail_] = std::move(item);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
tail_ = next_tail;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Submits a task to the queue, and starts the worker thread if not already running.
|
||||
*
|
||||
* @param task Task to be submitted.
|
||||
*/
|
||||
void submit_task(std::unique_ptr<cann_task> && task) {
|
||||
while (!enqueue(std::move(task))) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!running_) {
|
||||
running_ = true;
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits until the queue is completely empty and no tasks are being processed.
|
||||
*/
|
||||
void wait() {
|
||||
while (running_ && head_ != tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stops the task queue and joins the worker thread.
|
||||
*/
|
||||
void stop() {
|
||||
running_ = false;
|
||||
if (thread_.joinable()) {
|
||||
thread_.join();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Worker thread function that continuously dequeues and executes tasks.
|
||||
*/
|
||||
void execute() {
|
||||
ggml_cann_set_device(device_);
|
||||
|
||||
while (running_) {
|
||||
if (head_ == tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
buffer_[head_]->run_task();
|
||||
buffer_[head_].reset();
|
||||
head_ = (head_ + 1) & mask_;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<cann_task>> buffer_;
|
||||
const size_t capacity_;
|
||||
size_t mask_;
|
||||
size_t head_;
|
||||
size_t tail_;
|
||||
bool running_;
|
||||
std::thread thread_;
|
||||
int32_t device_;
|
||||
};
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
struct ggml_graph_node_properties {
|
||||
// dst tensor
|
||||
@@ -474,7 +350,6 @@ struct ggml_backend_cann_context {
|
||||
ggml_cann_graph_lru_cache graph_lru_cache;
|
||||
bool acl_graph_mode = true;
|
||||
#endif
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
// Rope Cache
|
||||
ggml_cann_rope_cache rope_cache;
|
||||
@@ -488,15 +363,10 @@ struct ggml_backend_cann_context {
|
||||
* @brief Constructor for initializing the context with a given device.
|
||||
* @param device Device ID.
|
||||
*/
|
||||
explicit ggml_backend_cann_context(int device) :
|
||||
device(device),
|
||||
name("CANN" + std::to_string(device)),
|
||||
task_queue(1024, device) {
|
||||
explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
|
||||
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
|
||||
#ifdef USE_ACL_GRAPH
|
||||
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
|
||||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
@@ -509,7 +379,6 @@ struct ggml_backend_cann_context {
|
||||
*/
|
||||
~ggml_backend_cann_context() {
|
||||
ggml_cann_set_device(device);
|
||||
task_queue.stop();
|
||||
if (copy_event != nullptr) {
|
||||
ACL_CHECK(aclrtDestroyEvent(copy_event));
|
||||
}
|
||||
|
||||
@@ -22,24 +22,24 @@
|
||||
|
||||
#include "ggml-cann.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <stdarg.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <acl/acl.h>
|
||||
#include <aclnnop/aclnn_trans_matmul_weight.h>
|
||||
#include <stdarg.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <unordered_set>
|
||||
#include <optional>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-cann/aclnn_ops.h"
|
||||
#include "ggml-cann/common.h"
|
||||
#include "ggml.h"
|
||||
#include <queue>
|
||||
#include <unordered_set>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
|
||||
@@ -1177,19 +1177,18 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
|
||||
* across calls. This reduces overhead from repeated memory allocation and deallocation.
|
||||
*/
|
||||
static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
|
||||
aclTensor * weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
aclOpExecutor * executor;
|
||||
|
||||
// TransMatmulWeight
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
|
||||
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
|
||||
// Avoid frequent malloc/free of the workspace.
|
||||
g_nz_workspaces[device].realloc(workspaceSize);
|
||||
|
||||
void * g_nz_workspace = g_nz_workspaces[device].get();
|
||||
|
||||
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
|
||||
ACL_CHECK(aclDestroyTensor(weightTransposed));
|
||||
}
|
||||
|
||||
// TODO: need handle tensor which has paddings.
|
||||
@@ -1641,7 +1640,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
|
||||
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||
},
|
||||
/* .device = */
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
@@ -1949,7 +1948,8 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ggml_cann_async_memcpy(cann_ctx, (char *) tensor->data + offset, data, size, ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
|
||||
cann_ctx->stream()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1974,7 +1974,8 @@ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
|
||||
GGML_ASSERT(!ggml_is_quantized(tensor->type));
|
||||
|
||||
ggml_cann_async_memcpy(cann_ctx, data, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
|
||||
cann_ctx->stream()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -2035,7 +2036,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
|
||||
|
||||
// wait for task_queue empty to keep task order.
|
||||
cann_ctx_src->task_queue.wait();
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
|
||||
cann_ctx_src->stream()));
|
||||
// record event on src stream after the copy
|
||||
@@ -2068,7 +2068,6 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
*/
|
||||
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
||||
ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
|
||||
cann_ctx->task_queue.wait();
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
||||
}
|
||||
@@ -2485,6 +2484,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] > 896) {
|
||||
return false;
|
||||
}
|
||||
#ifdef ASCEND_310P
|
||||
if (!ggml_is_contiguous(op->src[0])) {
|
||||
return false;
|
||||
@@ -2521,10 +2523,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
// value of paddingW should be at most half of kernelW
|
||||
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
|
||||
}
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_REPEAT:
|
||||
|
||||
@@ -145,26 +145,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
include(CheckCXXSourceRuns)
|
||||
|
||||
function(check_arm_feature tag code)
|
||||
macro(check_arm_feature tag feature code)
|
||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
|
||||
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_${tag})
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}" PARENT_SCOPE)
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}")
|
||||
else()
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
|
||||
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
|
||||
if (GGML_MACHINE_SUPPORTS_no${tag})
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}" PARENT_SCOPE)
|
||||
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}")
|
||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature})
|
||||
endif()
|
||||
endif()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
endfunction()
|
||||
endmacro()
|
||||
|
||||
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||
check_arm_feature(dotprod DOTPROD "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(i8mm MATMUL_INT8 "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
|
||||
check_arm_feature(sve SVE "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
|
||||
check_arm_feature(sme SME "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
|
||||
|
||||
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
|
||||
else()
|
||||
@@ -216,35 +217,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# show enabled features
|
||||
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
|
||||
set(FEAT_INPUT_FILE "NUL")
|
||||
else()
|
||||
set(FEAT_INPUT_FILE "/dev/null")
|
||||
endif()
|
||||
message(STATUS "Checking for ARM features using flags:")
|
||||
foreach(flag IN LISTS ARCH_FLAGS)
|
||||
message(STATUS " ${flag}")
|
||||
endforeach()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
|
||||
INPUT_FILE ${FEAT_INPUT_FILE}
|
||||
OUTPUT_VARIABLE ARM_FEATURE
|
||||
RESULT_VARIABLE ARM_FEATURE_RESULT
|
||||
)
|
||||
if (ARM_FEATURE_RESULT)
|
||||
message(WARNING "Failed to get ARM features")
|
||||
else()
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
|
||||
if (NOT ${feature_pos} EQUAL -1)
|
||||
# Special handling for MATMUL_INT8 when machine doesn't support i8mm
|
||||
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
|
||||
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
|
||||
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
|
||||
else()
|
||||
message(STATUS "ARM feature ${feature} enabled")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
include(CheckCXXSourceCompiles)
|
||||
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
|
||||
set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}")
|
||||
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
|
||||
set(ARM_FEATURE "HAVE_${feature}")
|
||||
check_cxx_source_compiles(
|
||||
"
|
||||
#if !defined(__ARM_FEATURE_${feature})
|
||||
# error \"Feature ${feature} is not defined\"
|
||||
#endif
|
||||
int main() { return 0; }
|
||||
"
|
||||
${ARM_FEATURE}
|
||||
)
|
||||
endforeach()
|
||||
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "x86")
|
||||
message(STATUS "x86 detected")
|
||||
|
||||
@@ -646,7 +646,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -1041,7 +1041,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
||||
|
||||
@@ -1989,7 +1989,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
||||
int64_t xstart = 0;
|
||||
int anr = nr - nr % 16;; // Used to align nr with boundary of 16
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
||||
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
||||
@@ -2727,7 +2727,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
xstart = anc/8;
|
||||
y = 0;
|
||||
}
|
||||
#endif //AVX512F
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
@@ -3467,7 +3467,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
__m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
|
||||
scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512BW__) && defined(__AVX512DQ__)
|
||||
|
||||
int anc = nc - nc % 16; // Used to align nc with boundary of 16
|
||||
|
||||
@@ -4947,7 +4947,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
y = 0;
|
||||
}
|
||||
|
||||
#endif //AVX512F
|
||||
#endif // __AVX512BW__ && __AVX512DQ__
|
||||
|
||||
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
||||
for (; y < anr / 4; y += 4) {
|
||||
|
||||
@@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
|
||||
@@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
@@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
@@ -988,7 +989,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_I32:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32;
|
||||
default:
|
||||
return false;
|
||||
};
|
||||
|
||||
@@ -612,6 +612,45 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_sum_rows;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t net0;
|
||||
int64_t net1;
|
||||
int64_t net2;
|
||||
int64_t net3;
|
||||
uint64_t nbt0;
|
||||
uint64_t nbt1;
|
||||
uint64_t nbt2;
|
||||
uint64_t nbt3;
|
||||
bool outb;
|
||||
} ggml_metal_kargs_cumsum_blk;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t net0;
|
||||
int64_t net1;
|
||||
int64_t net2;
|
||||
int64_t net3;
|
||||
uint64_t nbt0;
|
||||
uint64_t nbt1;
|
||||
uint64_t nbt2;
|
||||
uint64_t nbt3;
|
||||
} ggml_metal_kargs_cumsum_add;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
||||
@@ -311,6 +311,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_cumsum(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
n_fuse = ggml_metal_op_soft_max(ctx, idx);
|
||||
@@ -539,7 +543,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
|
||||
|
||||
@@ -585,7 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
@@ -694,7 +698,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float scale;
|
||||
float bias;
|
||||
@@ -733,7 +737,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
@@ -772,7 +776,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
@@ -802,7 +806,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
if (op->src[1]) {
|
||||
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
|
||||
@@ -834,18 +838,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
|
||||
|
||||
//[encoder setComputePipelineState:pipeline];
|
||||
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
//if (src1) {
|
||||
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
//} else {
|
||||
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
//}
|
||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
@@ -907,7 +899,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -941,14 +933,6 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
//[encoder setComputePipelineState:pipeline];
|
||||
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
@@ -961,6 +945,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
|
||||
|
||||
int nth = 1;
|
||||
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne00 <= nth*nth);
|
||||
|
||||
const int64_t net0 = (ne00 + nth - 1) / nth;
|
||||
const int64_t net1 = ne01;
|
||||
const int64_t net2 = ne02;
|
||||
const int64_t net3 = ne03;
|
||||
|
||||
const uint64_t nbt0 = sizeof(float);
|
||||
const uint64_t nbt1 = net0*nbt0;
|
||||
const uint64_t nbt2 = net1*nbt1;
|
||||
const uint64_t nbt3 = net2*nbt2;
|
||||
|
||||
const size_t smem = GGML_PAD(32*sizeof(float), 16);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
|
||||
{
|
||||
ggml_metal_kargs_cumsum_blk args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.net0 =*/ net0,
|
||||
/*.net1 =*/ net1,
|
||||
/*.net2 =*/ net2,
|
||||
/*.net3 =*/ net3,
|
||||
/*.nbt0 =*/ nbt0,
|
||||
/*.nbt1 =*/ nbt1,
|
||||
/*.nbt2 =*/ nbt2,
|
||||
/*.nbt3 =*/ nbt3,
|
||||
/*.outb =*/ ne00 > nth,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
||||
}
|
||||
|
||||
if (ne00 > nth) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_kargs_cumsum_blk args = {
|
||||
/*.ne00 =*/ net0,
|
||||
/*.ne01 =*/ net1,
|
||||
/*.ne02 =*/ net2,
|
||||
/*.ne03 =*/ net3,
|
||||
/*.nb00 =*/ nbt0,
|
||||
/*.nb01 =*/ nbt1,
|
||||
/*.nb02 =*/ nbt2,
|
||||
/*.nb03 =*/ nbt3,
|
||||
/*.net0 =*/ net0,
|
||||
/*.net1 =*/ net1,
|
||||
/*.net2 =*/ net2,
|
||||
/*.net3 =*/ net3,
|
||||
/*.nbt0 =*/ nbt0,
|
||||
/*.nbt1 =*/ nbt1,
|
||||
/*.nbt2 =*/ nbt2,
|
||||
/*.nbt3 =*/ nbt3,
|
||||
/*.outb =*/ false,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
|
||||
}
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
|
||||
|
||||
ggml_metal_kargs_cumsum_add args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.net0 =*/ net0,
|
||||
/*.net1 =*/ net1,
|
||||
/*.net2 =*/ net2,
|
||||
/*.net3 =*/ net3,
|
||||
/*.nbt0 =*/ nbt0,
|
||||
/*.nbt1 =*/ nbt1,
|
||||
/*.nbt2 =*/ nbt2,
|
||||
/*.nbt3 =*/ nbt3,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_add);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -972,7 +1099,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
|
||||
@@ -1017,7 +1144,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
|
||||
|
||||
@@ -1081,7 +1208,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
@@ -1169,7 +1296,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_ssm_conv args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1224,7 +1351,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const ggml_tensor * src3 = op->src[3];
|
||||
const ggml_tensor * src4 = op->src[4];
|
||||
@@ -1310,7 +1437,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
|
||||
const int64_t T = op->src[0]->ne[2];
|
||||
@@ -1351,7 +1478,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
@@ -1424,7 +1551,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t * opts = op->op_params;
|
||||
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
|
||||
@@ -1488,7 +1615,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
||||
@@ -1729,7 +1856,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
// src2 = ids
|
||||
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
|
||||
@@ -2191,8 +2318,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (has_mask) {
|
||||
@@ -2222,8 +2347,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
@@ -2363,8 +2486,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
@@ -2695,7 +2816,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, op->op_params, sizeof(float));
|
||||
@@ -2743,7 +2864,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
|
||||
|
||||
@@ -2798,7 +2919,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, op->op_params, sizeof(float));
|
||||
@@ -2934,7 +3055,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
// make sure we have one or more position id(ne10) per token(ne02)
|
||||
GGML_ASSERT(ne10 % ne02 == 0);
|
||||
@@ -3028,7 +3149,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
|
||||
@@ -3178,7 +3299,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||
|
||||
@@ -3223,7 +3344,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||
|
||||
@@ -3277,7 +3398,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float sf0 = (float)ne0/op->src[0]->ne[0];
|
||||
const float sf1 = (float)ne1/op->src[0]->ne[1];
|
||||
@@ -3330,7 +3451,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_pad args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -3374,7 +3495,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_pad_reflect_1d args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -3418,7 +3539,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float start;
|
||||
float step;
|
||||
@@ -3436,12 +3557,6 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
|
||||
|
||||
//[encoder setComputePipelineState:pipeline];
|
||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
||||
//[encoder setBytes:&args length:sizeof(args) atIndex:1];
|
||||
|
||||
//[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
||||
@@ -3460,7 +3575,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const int dim = op->op_params[0];
|
||||
const int max_period = op->op_params[1];
|
||||
@@ -3494,7 +3609,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_kargs_argmax args = {
|
||||
/*.ne00 = */ ne00,
|
||||
@@ -3535,7 +3650,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
|
||||
|
||||
@@ -3545,7 +3660,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
const int nptg = (ne00 + nth - 1)/nth;
|
||||
const int npr = (ne00 + nth - 1)/nth;
|
||||
|
||||
// Metal kernels require the buffer size to be multiple of 16 bytes
|
||||
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
||||
@@ -3557,7 +3672,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
|
||||
if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
|
||||
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
|
||||
std::swap(bid_dst, bid_tmp);
|
||||
}
|
||||
|
||||
@@ -3579,7 +3694,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
|
||||
|
||||
@@ -3611,8 +3726,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
std::swap(bid_dst, bid_tmp);
|
||||
@@ -3632,7 +3745,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float slope;
|
||||
memcpy(&slope, op->op_params, sizeof(float));
|
||||
@@ -3668,7 +3781,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
@@ -3704,7 +3817,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -197,6 +197,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
||||
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
} break;
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
res *= 2;
|
||||
|
||||
@@ -1832,6 +1832,117 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_blk(
|
||||
constant ggml_metal_kargs_cumsum_blk & args,
|
||||
device const char * src0,
|
||||
device char * tmp,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * src0_row = (device const float *) (src0 +
|
||||
args.nb01*i01 +
|
||||
args.nb02*i02 +
|
||||
args.nb03*i03);
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
|
||||
float v = 0.0f;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
v = src0_row[i00 + tpitg.x];
|
||||
}
|
||||
|
||||
float s = simd_prefix_inclusive_sum(v);
|
||||
|
||||
if (tiisg == N_SIMDWIDTH - 1) {
|
||||
shmem_f32[sgitg] = s;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
s += shmem_f32[sgitg];
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] = s;
|
||||
}
|
||||
|
||||
if (args.outb && tpitg.x == ntg.x - 1) {
|
||||
device float * tmp_row = (device float *) tmp +
|
||||
args.net0*i01 +
|
||||
args.net0*args.net1*i02 +
|
||||
args.net0*args.net1*args.net2*i03;
|
||||
|
||||
tmp_row[ib] = s;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_add(
|
||||
constant ggml_metal_kargs_cumsum_add & args,
|
||||
device const char * tmp,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
if (ib == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * tmp_row = (device const float *) (tmp +
|
||||
args.nbt1*i01 +
|
||||
args.nbt2*i02 +
|
||||
args.nbt3*i03);
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
@@ -4543,7 +4654,7 @@ typedef void (argsort_t)(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
@@ -4553,7 +4664,7 @@ kernel void kernel_argsort_f32_i32(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
@@ -4565,10 +4676,10 @@ kernel void kernel_argsort_f32_i32(
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||
|
||||
// initialize indices
|
||||
smem_i32[col] = i00 + col;
|
||||
shmem_i32[col] = i00 + col;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@@ -4577,20 +4688,20 @@ kernel void kernel_argsort_f32_i32(
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (smem_i32[col] >= args.ne00 ||
|
||||
(smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
|
||||
x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
|
||||
if (shmem_i32[col] >= args.ne00 ||
|
||||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(smem_i32[col], smem_i32[ixj]);
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (smem_i32[ixj] >= args.ne00 ||
|
||||
(smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
|
||||
x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
|
||||
if (shmem_i32[ixj] >= args.ne00 ||
|
||||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(smem_i32[col], smem_i32[ixj]);
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4603,7 +4714,7 @@ kernel void kernel_argsort_f32_i32(
|
||||
if (i00 + col < args.ne00) {
|
||||
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
dst[col] = smem_i32[col];
|
||||
dst[col] = shmem_i32[col];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4628,12 +4739,13 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int im = tgpig[0] / args.ne01;
|
||||
int i01 = tgpig[0] % args.ne01;
|
||||
int i02 = tgpig[1];
|
||||
int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2*args.len);
|
||||
const int im = tgpig[0] / args.ne01;
|
||||
const int i01 = tgpig[0] % args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2 * args.len);
|
||||
|
||||
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
|
||||
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
|
||||
@@ -4657,54 +4769,101 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||
+ args.nb02*i02
|
||||
+ args.nb03*i03);
|
||||
|
||||
for (int k = tpitg.x; k < (int) total; k += ntg.x) {
|
||||
// find partition (i,j) such that i+j = k
|
||||
int low = k > len1 ? k - len1 : 0;
|
||||
int high = MIN(k, len0);
|
||||
if (total == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k - mid - 1];
|
||||
const int k0 = tpitg.x * chunk;
|
||||
const int k1 = min(k0 + chunk, total);
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
if (k0 >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (val0 <= val1) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
} else {
|
||||
if (val0 >= val1) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
int low = k0 > len1 ? k0 - len1 : 0;
|
||||
int high = MIN(k0, len0);
|
||||
|
||||
// binary-search partition (i, j) such that i + j = k
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
|
||||
bool take_left;
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
const int i = low;
|
||||
const int j = k - i;
|
||||
if (take_left) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
|
||||
int i = low;
|
||||
int j = k0 - i;
|
||||
|
||||
// keep the merge fronts into registers
|
||||
int32_t idx0 = 0;
|
||||
float val0 = 0.0f;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
|
||||
int32_t idx1 = 0;
|
||||
float val1 = 0.0f;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
|
||||
for (int k = k0; k < k1; ++k) {
|
||||
int32_t out_idx;
|
||||
|
||||
if (i >= len0) {
|
||||
out_idx = tmp1[j];
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp1[j++];
|
||||
}
|
||||
break;
|
||||
} else if (j >= len1) {
|
||||
out_idx = tmp0[i];
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp0[i++];
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
const int32_t idx0 = tmp0[i];
|
||||
const int32_t idx1 = tmp1[j];
|
||||
bool take_left;
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
out_idx = (order == GGML_SORT_ORDER_ASC)
|
||||
? (val0 <= val1 ? idx0 : idx1)
|
||||
: (val0 >= val1 ? idx0 : idx1);
|
||||
if (take_left) {
|
||||
out_idx = idx0;
|
||||
++i;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
} else {
|
||||
out_idx = idx1;
|
||||
++j;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst[k] = out_idx;
|
||||
@@ -6401,6 +6560,7 @@ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
|
||||
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
|
||||
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
|
||||
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
|
||||
#endif
|
||||
|
||||
@@ -119,6 +119,7 @@ set(GGML_OPENCL_KERNELS
|
||||
pad
|
||||
repeat
|
||||
mul_mat_f16_f32
|
||||
mul_mm_f16_f32_kq_kqv
|
||||
conv2d
|
||||
conv2d_f16_f32
|
||||
flash_attn_f32_f16
|
||||
|
||||
@@ -407,6 +407,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_program program_mul_mv_f32_f32;
|
||||
cl_program program_mul;
|
||||
cl_program program_mul_mat_f16_f32_tiled;
|
||||
cl_program program_mul_mm_f16_f32_kqv;
|
||||
cl_program program_mul_mm_f16_f32_kq;
|
||||
cl_program program_div;
|
||||
cl_program program_sub;
|
||||
cl_program program_norm;
|
||||
@@ -481,6 +483,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mat_f16_f32;
|
||||
cl_kernel kernel_mul_mat_f16_f32_l4;
|
||||
cl_kernel kernel_mul_mat_f16_f32_tiled;
|
||||
cl_kernel kernel_mul_mm_f16_f32_kqv;
|
||||
cl_kernel kernel_mul_mm_f16_f32_kq;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
|
||||
@@ -1235,6 +1239,25 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_f16_f32_kq_kqv
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_f16_f32_kq_kqv.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl");
|
||||
#endif
|
||||
backend_ctx->program_mul_mm_f16_f32_kqv =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV ");
|
||||
backend_ctx->program_mul_mm_f16_f32_kq =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -5682,7 +5705,7 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
|
||||
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL));
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
@@ -6665,6 +6688,146 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne02 = src0->ne[2];
|
||||
|
||||
const cl_ulong nb01 = src0->nb[1];
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
|
||||
const int ne10 = src1->ne[0];
|
||||
const int ne11 = src1->ne[1];
|
||||
const int ne12 = src1->ne[2];
|
||||
|
||||
const cl_ulong nb10 = src1->nb[0];
|
||||
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
||||
cl_kernel kernel;
|
||||
cl_context context = backend_ctx->context;
|
||||
|
||||
cl_int status;
|
||||
cl_image_format img_fmt_1d;
|
||||
cl_image_desc img_desc_1d;
|
||||
cl_buffer_region region;
|
||||
cl_mem A_image1d;
|
||||
cl_mem A_sub_buffer;
|
||||
cl_mem B_sub_buffer;
|
||||
cl_mem D_image1d;
|
||||
cl_mem D_sub_buffer;
|
||||
|
||||
int M = ne01;
|
||||
int N = ne1;
|
||||
int K = ne00;
|
||||
|
||||
if (nb01 > nb02) {
|
||||
// KQ
|
||||
kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;
|
||||
} else {
|
||||
// KQV
|
||||
kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv;
|
||||
}
|
||||
// create sub-buffer for A
|
||||
// <--------------------------------------------> //
|
||||
extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra;
|
||||
|
||||
region.origin = (extra0->offset);
|
||||
if (nb01 > nb02) {
|
||||
// KQ
|
||||
region.size = nb01 * ne01;
|
||||
} else {
|
||||
// KQV
|
||||
region.size = nb02 * ne02;
|
||||
}
|
||||
|
||||
A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// <--------------------------------------------> //
|
||||
|
||||
// create sub-buffer for B
|
||||
// <--------------------------------------------> //
|
||||
region.origin = (extra1->offset);
|
||||
region.size = nb10 * ne10 * ne11 * ne12;
|
||||
B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
// <--------------------------------------------> //
|
||||
|
||||
img_fmt_1d = {CL_RGBA, CL_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
if (nb01 > nb02) {
|
||||
img_desc_1d.image_width = (nb01 * ne01 / 4)/4;
|
||||
}
|
||||
else {
|
||||
img_desc_1d.image_width = (nb02 * ne02 / 4)/4;
|
||||
}
|
||||
img_desc_1d.buffer = A_sub_buffer;
|
||||
A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create sub-buffer for output C
|
||||
// <--------------------------------------------> //
|
||||
region.origin = (extrad->offset);
|
||||
region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes
|
||||
D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
// <--------------------------------------------> //
|
||||
|
||||
// create image for C output
|
||||
// <--------------------------------------------> //
|
||||
img_fmt_1d = {CL_R, CL_FLOAT};
|
||||
memset(&img_desc_1d, 0, sizeof(img_desc_1d));
|
||||
img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
|
||||
img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4;
|
||||
img_desc_1d.buffer = D_sub_buffer;
|
||||
D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
// <--------------------------------------------> //
|
||||
|
||||
int offset_src0 = 0;
|
||||
int offset_src1 = 0;
|
||||
|
||||
// set kernel args
|
||||
// <--------------------------------------------> //
|
||||
cl_uint k_arg = 0;
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01));
|
||||
|
||||
size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)};
|
||||
size_t local_work_size[3] = {64, 1, 2};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
|
||||
// deallocate sub buffers and images
|
||||
// <--------------------------------------------> //
|
||||
CL_CHECK(clReleaseMemObject(A_image1d));
|
||||
CL_CHECK(clReleaseMemObject(D_image1d));
|
||||
CL_CHECK(clReleaseMemObject(A_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(B_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(D_sub_buffer));
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -6731,6 +6894,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
cl_context context = backend_ctx->context;
|
||||
|
||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) {
|
||||
|
||||
// init CL objects
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
#define LM_FIRST_256B 0
|
||||
#define LM_SECOND_256B 64
|
||||
#define LM_THIRD_256B 128
|
||||
#define LM_FOURTH_256B 192
|
||||
|
||||
|
||||
inline float16 mm_load_a(
|
||||
image1d_buffer_t matrix_A,
|
||||
uint subMatrixAStartInElements,
|
||||
int nb01,
|
||||
int line_stride_matrix_A_in_bytes
|
||||
) {
|
||||
__private float8 regA;
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
|
||||
#ifdef KQV
|
||||
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
|
||||
#else // KQ
|
||||
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
|
||||
#endif
|
||||
|
||||
regA.s0123 = read_imagef(matrix_A, a_texCoord/4);
|
||||
regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4);
|
||||
|
||||
return convert_float16(as_half16(regA));
|
||||
}
|
||||
|
||||
inline float4 alu_32(
|
||||
float16 regA,
|
||||
__local float4* matrix_B_vec
|
||||
) {
|
||||
|
||||
__private float4 rC = 0;
|
||||
int i = get_sub_group_id() * 64;
|
||||
|
||||
rC += regA.s0 * matrix_B_vec[i];
|
||||
rC += regA.s1 * matrix_B_vec[i + 16];
|
||||
rC += regA.s4 * matrix_B_vec[i + 1];
|
||||
rC += regA.s5 * matrix_B_vec[i + 17];
|
||||
rC += regA.s8 * matrix_B_vec[i + 2];
|
||||
rC += regA.s9 * matrix_B_vec[i + 18];
|
||||
rC += regA.sc * matrix_B_vec[i + 3];
|
||||
rC += regA.sd * matrix_B_vec[i + 19];
|
||||
|
||||
i += 32;
|
||||
|
||||
rC += regA.s2 * matrix_B_vec[i];
|
||||
rC += regA.s3 * matrix_B_vec[i + 16];
|
||||
rC += regA.s6 * matrix_B_vec[i + 1];
|
||||
rC += regA.s7 * matrix_B_vec[i + 17];
|
||||
rC += regA.sa * matrix_B_vec[i + 2];
|
||||
rC += regA.sb * matrix_B_vec[i + 18];
|
||||
rC += regA.se * matrix_B_vec[i + 3];
|
||||
rC += regA.sf * matrix_B_vec[i + 19];
|
||||
|
||||
return rC;
|
||||
}
|
||||
|
||||
inline float16 alu_16(
|
||||
float16 regA,
|
||||
__local float* matrix_B_local
|
||||
) {
|
||||
float16 out;
|
||||
__local float4* matrix_B_vec = (__local float4*)matrix_B_local;
|
||||
|
||||
out.s0123 = alu_32(regA, matrix_B_vec);
|
||||
out.s4567 = alu_32(regA, matrix_B_vec + 4);
|
||||
out.s89ab = alu_32(regA, matrix_B_vec + 8);
|
||||
out.scdef = alu_32(regA, matrix_B_vec + 12);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
inline void mm_mad(
|
||||
__local float* matrix_B_local,
|
||||
float16 regA,
|
||||
float8 regB,
|
||||
uint b_localOffsetInWords,
|
||||
float16* regC0_ptr,
|
||||
float16* regC1_ptr
|
||||
) {
|
||||
int offset = b_localOffsetInWords + get_sub_group_id() * 256;
|
||||
|
||||
matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
|
||||
matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
|
||||
matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
|
||||
matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
|
||||
|
||||
float16 add0 = alu_16(regA, matrix_B_local);
|
||||
*regC0_ptr += add0;
|
||||
|
||||
matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
|
||||
matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
|
||||
matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
|
||||
matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
|
||||
|
||||
float16 add1 = alu_16(regA, matrix_B_local);
|
||||
*regC1_ptr += add1;
|
||||
}
|
||||
|
||||
inline void mm_store_c_N(
|
||||
__write_only image1d_buffer_t matrix_C,
|
||||
float16 regC0,
|
||||
float16 regC1,
|
||||
uint subMatrixCStartInElements,
|
||||
int line_stride_matrix_C_in_bytes,
|
||||
int mask
|
||||
) {
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
|
||||
uint strideInWords = line_stride_matrix_C_in_bytes/4;
|
||||
uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m);
|
||||
|
||||
uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords;
|
||||
uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords;
|
||||
uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords;
|
||||
uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords;
|
||||
uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords;
|
||||
uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords;
|
||||
uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords;
|
||||
uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords;
|
||||
uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords;
|
||||
uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
|
||||
uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
|
||||
uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
|
||||
uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
|
||||
uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
|
||||
uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
|
||||
uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
|
||||
uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
|
||||
uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
|
||||
uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
|
||||
uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
|
||||
uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
|
||||
uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
|
||||
uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
|
||||
uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
|
||||
uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
|
||||
uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
|
||||
uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
|
||||
uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
|
||||
uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
|
||||
uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
|
||||
uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
|
||||
|
||||
if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); }
|
||||
if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); }
|
||||
if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); }
|
||||
if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); }
|
||||
if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); }
|
||||
if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); }
|
||||
if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); }
|
||||
if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); }
|
||||
if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); }
|
||||
if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); }
|
||||
if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
|
||||
if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
|
||||
if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
|
||||
if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
|
||||
if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
|
||||
if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
|
||||
if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
|
||||
if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
|
||||
if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
|
||||
if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
|
||||
if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
|
||||
if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
|
||||
if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
|
||||
if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
|
||||
if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
|
||||
if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
|
||||
if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
|
||||
if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
|
||||
if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
|
||||
if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
|
||||
if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
|
||||
if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
|
||||
}
|
||||
|
||||
#define TILESIZE_K 16
|
||||
#define TILESIZE_M 64
|
||||
#define TILESIZE_N 32
|
||||
#ifdef KQV
|
||||
__kernel void mul_mm_f16_f32_kqv(
|
||||
#else
|
||||
__kernel void mul_mm_f16_f32_kq(
|
||||
#endif
|
||||
__read_only image1d_buffer_t matrix_A,
|
||||
int offset0,
|
||||
__global float* matrix_B,
|
||||
int offset1,
|
||||
__write_only image1d_buffer_t matrix_C,
|
||||
int offsetd,
|
||||
int M, int K, int N,
|
||||
int D_A,
|
||||
int D_B,
|
||||
int nb01
|
||||
) {
|
||||
|
||||
uint block_id_m = get_global_id(1);
|
||||
uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
|
||||
uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
|
||||
|
||||
__private float16 regA;
|
||||
__private float8 regB;
|
||||
__private float16 regC0;
|
||||
__private float16 regC1;
|
||||
|
||||
const uint col = block_id_m * TILESIZE_M;
|
||||
const uint row = block_id_n * TILESIZE_N;
|
||||
const uint depth_A = block_id_d / (D_B/D_A);
|
||||
const uint depth_B = block_id_d;
|
||||
|
||||
#ifdef KQV
|
||||
int line_stride_matrix_A_in_bytes = nb01 * M;
|
||||
int line_stride_matrix_B_in_bytes = K * N * 4;
|
||||
#else
|
||||
int line_stride_matrix_A_in_bytes = K * D_A * 2;
|
||||
int line_stride_matrix_B_in_bytes = K * D_B * 4;
|
||||
#endif
|
||||
|
||||
int line_stride_matrix_C_in_bytes = M * 4;
|
||||
|
||||
const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
|
||||
const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
|
||||
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
|
||||
uint b_localOffsetInWords = (sub_block_id_m/16)*16
|
||||
+ ((((sub_block_id_m)>>0)&1)<<2)
|
||||
+ ((((sub_block_id_m)>>1)&1)<<3)
|
||||
+ ((((sub_block_id_m)>>2)&1)<<0)
|
||||
+ ((((sub_block_id_m)>>3)&1)<<1);
|
||||
|
||||
uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
|
||||
uint b_globalOffsetInWords00, b_globalOffsetInWords16;
|
||||
#ifdef KQV
|
||||
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
|
||||
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
|
||||
uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
|
||||
uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
|
||||
#else
|
||||
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
|
||||
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
|
||||
uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
|
||||
uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
|
||||
#endif
|
||||
|
||||
__local float matrix_B_local[1024];
|
||||
|
||||
for (uint step=0; step < K; step+=TILESIZE_K) {
|
||||
size_t sub_block_id_m = get_local_id(0);
|
||||
regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
|
||||
|
||||
uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
|
||||
uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
|
||||
|
||||
regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
|
||||
regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
|
||||
|
||||
mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, ®C0, ®C1);
|
||||
|
||||
subMatrixAStartInElements += TILESIZE_K;
|
||||
subMatrixBStartInElements += TILESIZE_K;
|
||||
}
|
||||
|
||||
uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
|
||||
mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
|
||||
}
|
||||
|
||||
@@ -134,6 +134,15 @@ kernel void kernel_rms_norm_mul(
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
// The size of sum is sizeof(float)*subgroup_size.
|
||||
// Each subgroup writes its partial sum to this array.
|
||||
// So the number of subgroups per workgroup for this kernel cannot exceed the subgroup size.
|
||||
// This is generally true -
|
||||
// for subgroup size 64, workgroup size should be less than 4096 (the max is usually 1024).
|
||||
if (get_sub_group_id() == 0) {
|
||||
sum[get_sub_group_local_id()] = 0.0f;
|
||||
}
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0);
|
||||
@@ -148,24 +157,30 @@ kernel void kernel_rms_norm_mul(
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
sum[get_sub_group_id()] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
|
||||
if (get_local_id(0) < i) {
|
||||
sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
||||
}
|
||||
}
|
||||
if (get_local_id(0) == 0) {
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
//for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
|
||||
// if (get_local_id(0) < i) {
|
||||
// sum[get_local_id(0)] += sum[get_local_id(0) + i];
|
||||
// }
|
||||
//}
|
||||
//if (get_local_id(0) == 0) {
|
||||
// sum[0] /= ne00;
|
||||
//}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
//barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
float mean = sum[0];
|
||||
sumf = sum[get_sub_group_local_id()];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
float mean = sumf / ne00;
|
||||
float scale = 1.0f/sqrt(mean + eps);
|
||||
|
||||
global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
+111
-249
@@ -170,73 +170,31 @@ static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_sgn(x[i]);
|
||||
}
|
||||
}
|
||||
template<typename T, typename F>
|
||||
static void unary_op_generic_kernel(
|
||||
const T * x,
|
||||
T * dst,
|
||||
const int k,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
|
||||
const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,
|
||||
const sycl::nd_item<1> & item_ct1,
|
||||
F func) {
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
(void) ne3;
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_abs(x[i]);
|
||||
}
|
||||
}
|
||||
const int64_t i0 = i % ne0;
|
||||
const int64_t i1 = (i / ne0) % ne1;
|
||||
const int64_t i2 = (i / (ne0*ne1)) % ne2;
|
||||
const int64_t i3 = i / (ne0*ne1*ne2);
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_elu(x[i]);
|
||||
}
|
||||
}
|
||||
const char * src_base = (const char *) x;
|
||||
char * dst_base = (char *) dst;
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_gelu(x[i]);
|
||||
}
|
||||
}
|
||||
const T * srcp = (const T *)(src_base + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 );
|
||||
T * dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_silu(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_gelu_quick(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_gelu_erf(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_tanh(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_relu(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_sigmoid(x[i]);
|
||||
*dstp = func(*srcp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,27 +219,6 @@ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::n
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_hardsigmoid(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_hardswish(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_exp(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
@@ -289,19 +226,6 @@ static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::n
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_neg(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_step(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -620,6 +544,48 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
}
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
static inline void ggml_sycl_op_unary(
|
||||
ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
|
||||
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
|
||||
const size_t nb0 = src0->nb[0];
|
||||
const size_t nb1 = src0->nb[1];
|
||||
const size_t nb2 = src0->nb[2];
|
||||
const size_t nb3 = src0->nb[3];
|
||||
|
||||
const size_t nbd0 = dst->nb[0];
|
||||
const size_t nbd1 = dst->nb[1];
|
||||
const size_t nbd2 = dst->nb[2];
|
||||
const size_t nbd3 = dst->nb[3];
|
||||
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_generic_kernel(
|
||||
src, dst_ptr, k_elements,
|
||||
ne0, ne1, ne2, ne3,
|
||||
nb0, nb1, nb2, nb3,
|
||||
nbd0, nbd1, nbd2, nbd3,
|
||||
item_ct1,
|
||||
func
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
@@ -645,159 +611,75 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_sgn(x);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_abs(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_elu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_silu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_gelu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_gelu_quick(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_gelu_erf(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_tanh(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_relu(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_hardsigmoid(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_hardswish(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_exp(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
@@ -814,42 +696,22 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_neg(x);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_step(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
|
||||
return op_sigmoid(x);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -4360,21 +4360,22 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
}
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_STEP:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
return true;
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,21 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_d[i] = D_TYPE(abs(float(data_a[i])));
|
||||
}
|
||||
@@ -7,6 +7,7 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
#include "types.glsl"
|
||||
#include "flash_attn_base.glsl"
|
||||
@@ -108,6 +109,38 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
@@ -153,21 +186,6 @@ void main() {
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float mvf = masksh[c * cols_per_iter + col_tid][r];
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
|
||||
@@ -148,6 +149,37 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
float mask_cache[Bc * Br / WorkGroupSize];
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
mask_cache[idx / WorkGroupSize] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
@@ -208,7 +240,8 @@ void main() {
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
float f = mask_cache[idx / WorkGroupSize];
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return max(x, y);
|
||||
}
|
||||
|
||||
float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
|
||||
return max(x, y);
|
||||
}
|
||||
|
||||
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return x;
|
||||
}
|
||||
@@ -142,6 +146,44 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (nem1_bounds_check) {
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
@@ -158,31 +200,7 @@ void main() {
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (nem1_bounds_check) {
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
}
|
||||
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
||||
}
|
||||
|
||||
// Clear padding elements to -inf, so they don't contribute to rowmax
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
#version 450
|
||||
|
||||
#include "rte.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
|
||||
}
|
||||
@@ -11,29 +11,7 @@
|
||||
#define EXPERT_COUNT 8
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
#ifndef MMQ
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#else
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
#ifdef B_TYPE_VEC2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||
#endif
|
||||
#ifdef B_TYPE_VEC4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 4) readonly buffer IDS {int data_ids[];};
|
||||
#endif
|
||||
#include "mul_mat_vec_iface.glsl"
|
||||
|
||||
#include "dequant_funcs.glsl"
|
||||
|
||||
@@ -48,8 +26,7 @@ layout (push_constant) uniform parameter
|
||||
uint batch_stride_b;
|
||||
uint batch_stride_d;
|
||||
|
||||
uint enable_bias;
|
||||
uint enable_scale;
|
||||
uint fusion_flags;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint nei0;
|
||||
@@ -123,17 +100,24 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
@@ -171,17 +155,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
temp[j][n] += tmpsh[j][n][s];
|
||||
}
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
@@ -209,17 +200,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
if (p.enable_bias != 0) {
|
||||
#ifdef MUL_MAT_ID
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
|
||||
#else
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
#endif
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
#ifdef MUL_MAT_ID
|
||||
if (p.enable_scale != 0) {
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
|
||||
}
|
||||
#endif
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
#include "types.glsl"
|
||||
|
||||
#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
|
||||
#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
|
||||
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
|
||||
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
|
||||
|
||||
#ifndef MMQ
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(A_TYPE_VEC4)
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
#endif
|
||||
#else
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
#ifdef B_TYPE_VEC2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||
#endif
|
||||
#ifdef B_TYPE_VEC4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Fuse0 {D_TYPE data_fuse0[];};
|
||||
layout (binding = 4) readonly buffer Fuse1 {D_TYPE data_fuse1[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 5) readonly buffer IDS {int data_ids[];};
|
||||
#endif
|
||||
|
||||
@@ -8,14 +8,7 @@
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
#include "mul_mat_vec_iface.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
|
||||
uint nb03;
|
||||
uint nb13;
|
||||
uint nb23;
|
||||
uint enable_bias;
|
||||
uint fusion_flags;
|
||||
} p;
|
||||
|
||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||
@@ -120,9 +113,12 @@ void main() {
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
if (p.enable_bias != 0) {
|
||||
tmp[0] += FLOAT_TYPE(data_bias[idst]);
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
tmp[0] += FLOAT_TYPE(data_fuse0[idst]);
|
||||
}
|
||||
dst[idst] = tmp[0];
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
tmp[0] += FLOAT_TYPE(data_fuse1[idst]);
|
||||
}
|
||||
data_d[idst] = tmp[0];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,14 +10,7 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
|
||||
|
||||
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
|
||||
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
|
||||
#include "mul_mat_vec_iface.glsl"
|
||||
|
||||
layout(constant_id = 0) const int BLOCK_SIZE = 32;
|
||||
// gqa_ratio is in the range [1,8]
|
||||
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
|
||||
uint nchannels_y;
|
||||
uint b_offset;
|
||||
uint d_offset;
|
||||
uint enable_bias;
|
||||
uint fusion_flags;
|
||||
} p;
|
||||
|
||||
#if !USE_SUBGROUP_ADD
|
||||
@@ -151,10 +144,13 @@ void main() {
|
||||
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
|
||||
// dst is not transposed and not permuted
|
||||
const uint idst = (channel + c)*nrows_dst + row_dst;
|
||||
if (p.enable_bias != 0) {
|
||||
temp[c] += FLOAT_TYPE(data_bias[idst]);
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
temp[c] += FLOAT_TYPE(data_fuse0[idst]);
|
||||
}
|
||||
dst[idst] = temp[c];
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
|
||||
temp[c] += FLOAT_TYPE(data_fuse1[idst]);
|
||||
}
|
||||
data_d[idst] = temp[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
|
||||
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
|
||||
buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,21 +345,22 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
|
||||
// Repack 2x4 quants into one int
|
||||
// Add the 3rd bit instead of subtracting it to allow packing the quants
|
||||
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
|
||||
// vec4 for unpack8 used due to #12147
|
||||
const i8vec2 vals00 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
const i8vec2 vals01 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
const i8vec2 vals10 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
const i8vec2 vals11 = unpack8(int32_t(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303)))).xy |
|
||||
unpack8(int32_t(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101))) << 2)).xy;
|
||||
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
|
||||
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
|
||||
const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
|
||||
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
|
||||
}
|
||||
@@ -516,15 +517,15 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
|
||||
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
|
||||
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
|
||||
|
||||
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
|
||||
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
|
||||
const i8vec2 vals00 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))).xy |
|
||||
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
|
||||
const i8vec2 vals01 = (unpack8(int32_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))).xy |
|
||||
unpack8(int32_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4)).xy) - int8_t(32);
|
||||
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
|
||||
|
||||
if (iqs == 0) {
|
||||
const uint is = iqs_k / 4;
|
||||
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
|
||||
const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
|
||||
|
||||
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
data_d[i] = D_TYPE(-float(data_a[i]));
|
||||
}
|
||||
@@ -816,6 +816,9 @@ void process_shaders() {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
|
||||
|
||||
string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
}
|
||||
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -827,6 +830,8 @@ void process_shaders() {
|
||||
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
@@ -835,6 +840,8 @@ void process_shaders() {
|
||||
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
|
||||
@@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum):
|
||||
BAILINGMOE2 = auto()
|
||||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
AFMOE = auto()
|
||||
ERNIE4_5 = auto()
|
||||
ERNIE4_5_MOE = auto()
|
||||
HUNYUAN_MOE = auto()
|
||||
@@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_POST_NORM = auto()
|
||||
ATTN_ROT_EMBD = auto()
|
||||
ATTN_SINKS = auto()
|
||||
ATTN_GATE = auto()
|
||||
FFN_GATE_INP = auto()
|
||||
FFN_GATE_INP_SHEXP = auto()
|
||||
FFN_NORM = auto()
|
||||
@@ -776,6 +778,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
|
||||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
MODEL_ARCH.AFMOE: "afmoe",
|
||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
|
||||
MODEL_ARCH.FALCON_H1: "falcon-h1",
|
||||
@@ -828,6 +831,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
|
||||
MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
|
||||
@@ -2693,6 +2697,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.AFMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_GATE,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.ERNIE4_5: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
||||
@@ -314,6 +314,10 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.sinks", # openai-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_GATE: (
|
||||
"model.layers.{bid}.self_attn.gate_proj", # afmoe
|
||||
),
|
||||
|
||||
# Feed-forward norm
|
||||
MODEL_TENSOR.FFN_NORM: (
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||
@@ -340,11 +344,12 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.feedforward_layernorm", # apertus
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
# Pre feed-forward norm
|
||||
MODEL_TENSOR.FFN_PRE_NORM: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
"layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
|
||||
"model.layers.{bid}.pre_ff_layernorm.weight",
|
||||
"model.layers.{bid}.pre_mlp_layernorm", # afmoe
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@@ -370,6 +375,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
||||
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
||||
"model.layers.{bid}.mlp.router.gate", # afmoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@@ -380,6 +386,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
||||
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
||||
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
|
||||
"model.layers.{bid}.mlp.expert_bias", # afmoe
|
||||
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
|
||||
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
|
||||
),
|
||||
|
||||
@@ -35,6 +35,7 @@ add_library(llama
|
||||
unicode-data.cpp
|
||||
unicode.cpp
|
||||
unicode.h
|
||||
models/afmoe.cpp
|
||||
models/apertus.cpp
|
||||
models/arcee.cpp
|
||||
models/arctic.cpp
|
||||
|
||||
@@ -90,6 +90,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_AFMOE, "afmoe" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
@@ -333,6 +334,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_AFMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA4,
|
||||
{
|
||||
@@ -2444,6 +2475,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
||||
@@ -94,6 +94,7 @@ enum llm_arch {
|
||||
LLM_ARCH_BAILINGMOE2,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_AFMOE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_ERNIE4_5_MOE,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
@@ -312,6 +313,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_ATTN_GATE,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
|
||||
@@ -84,6 +84,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_15B: return "15B";
|
||||
case LLM_TYPE_16B: return "16B";
|
||||
case LLM_TYPE_20B: return "20B";
|
||||
case LLM_TYPE_26B: return "26B";
|
||||
case LLM_TYPE_27B: return "27B";
|
||||
case LLM_TYPE_30B: return "30B";
|
||||
case LLM_TYPE_32B: return "32B";
|
||||
@@ -695,6 +696,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
|
||||
// Set up interleaved sliding window attention (ISWA)
|
||||
// Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4)
|
||||
if (hparams.n_swa > 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.set_swa_pattern(4);
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
}
|
||||
|
||||
// Default to sigmoid if not set
|
||||
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 56: type = LLM_TYPE_6B; break;
|
||||
case 32: type = LLM_TYPE_26B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DECI:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@@ -5749,6 +5781,71 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
const int64_t n_expert_shared = hparams.n_expert_shared;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
// dual attention normalization
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
// attention projections
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
// Q/K normalization
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
// attention gating
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
|
||||
// dual ffn normalization
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
|
||||
// MoE layers
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
|
||||
|
||||
// grouped expert weights
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
|
||||
|
||||
// shared expert
|
||||
if (n_expert_shared > 0) {
|
||||
const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
|
||||
}
|
||||
} else {
|
||||
// Dense layers
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_ERNIE4_5_MOE:
|
||||
{
|
||||
@@ -7243,6 +7340,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_arcee>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_AFMOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_afmoe>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
{
|
||||
llm = std::make_unique<llm_build_ernie4_5>(*this, params);
|
||||
@@ -7528,6 +7629,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_MINIMAX_M2:
|
||||
case LLM_ARCH_COGVLM:
|
||||
case LLM_ARCH_PANGU_EMBED:
|
||||
case LLM_ARCH_AFMOE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
||||
@@ -76,6 +76,7 @@ enum llm_type {
|
||||
LLM_TYPE_15B,
|
||||
LLM_TYPE_16B,
|
||||
LLM_TYPE_20B,
|
||||
LLM_TYPE_26B,
|
||||
LLM_TYPE_27B,
|
||||
LLM_TYPE_30B,
|
||||
LLM_TYPE_32B,
|
||||
@@ -234,6 +235,7 @@ struct llama_layer {
|
||||
struct ggml_tensor * wk_enc = nullptr;
|
||||
struct ggml_tensor * wv_enc = nullptr;
|
||||
struct ggml_tensor * wo_enc = nullptr;
|
||||
struct ggml_tensor * wqkv_gate = nullptr;
|
||||
|
||||
// attention bias
|
||||
struct ggml_tensor * bq = nullptr;
|
||||
|
||||
+10
-5
@@ -4,6 +4,7 @@
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
@@ -1625,10 +1626,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
auto * ctx = new llama_sampler_grammar;
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
std::string trigger_pattern;
|
||||
llama_grammar * grammar = nullptr;
|
||||
// TODO: remove trigger_words support.
|
||||
if (trigger_words != nullptr && num_trigger_words > 0) {
|
||||
GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
|
||||
std::string trigger_pattern("[\\s\\S]*?(");
|
||||
trigger_pattern = "[\\s\\S]*?(";
|
||||
for (size_t i = 0; i < num_trigger_words; ++i) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
if (i > 0) {
|
||||
@@ -1637,15 +1640,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
|
||||
}
|
||||
trigger_pattern += ")[\\s\\S]*";
|
||||
const auto * trigger_pattern_c = trigger_pattern.c_str();
|
||||
trigger_patterns = &trigger_pattern_c;
|
||||
num_trigger_patterns = 1;
|
||||
|
||||
std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
|
||||
} else {
|
||||
grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
|
||||
}
|
||||
*ctx = {
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ grammar_str,
|
||||
/* .grammar_root = */ grammar_root,
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
|
||||
/* .grammar = */ grammar,
|
||||
};
|
||||
if (!ctx->grammar) {
|
||||
delete ctx;
|
||||
|
||||
@@ -443,6 +443,17 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_AFMOE:
|
||||
regex_exprs = {
|
||||
// Digit handling - uses custom implementation in unicode.cpp
|
||||
// Groups digits with leading 1-2 based on total length modulo 3
|
||||
"\\p{AFMoE_digits}",
|
||||
// CJK and Asian scripts (using direct Unicode literals)
|
||||
"[一-鿿㐀-䶿豈--ゟ゠-ヿ・-゚⼀-เ--ក-က-႟ꩠ-ꩿꧠ-가-ᄀ-ᇿ]+",
|
||||
// Main BPE pattern
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@@ -1993,6 +2004,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "grok-2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "afmoe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_AFMOE;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "minimax-m2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
|
||||
|
||||
@@ -50,6 +50,7 @@ enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
|
||||
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
|
||||
LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
|
||||
};
|
||||
|
||||
struct LLM_KV;
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// MuP scaling: embeddings * sqrt(hidden_size)
|
||||
// mup_enabled = true, hidden_size = 1024, scale = 32.0
|
||||
inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
|
||||
cb(inpL, "inp_embd_scaled", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// dual attention normalization (pre)
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * attn_inp = cur; // save input for gate computation
|
||||
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// compute gate from input
|
||||
ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
|
||||
cb(gate, "attn_gate_proj", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// Q/K normalization
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
// RoPE only for sliding_attention layers
|
||||
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
|
||||
((il + 1) % hparams.n_no_rope_layer_step) != 0;
|
||||
if (use_rope) {
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur_rope", il);
|
||||
}
|
||||
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
NULL, NULL, // wo will be applied after gating
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
|
||||
// attention gating: attn_out * sigmoid(gate) BEFORE o_proj
|
||||
gate = ggml_sigmoid(ctx0, gate);
|
||||
cb(gate, "attn_gate_sig", il);
|
||||
cur = ggml_mul(ctx0, cur, gate);
|
||||
cb(cur, "attn_gated", il);
|
||||
|
||||
// now apply output projection
|
||||
cur = build_lora_mm(model.layers[il].wo, cur);
|
||||
cb(cur, "attn_o_proj", il);
|
||||
}
|
||||
|
||||
// dual attention normalization (post)
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].attn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// dual ffn normalization (pre)
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// MoE or dense FFN
|
||||
if ((uint32_t)il >= hparams.n_layer_dense_lead) {
|
||||
// MoE layer with sigmoid routing, normalization, and scaling
|
||||
ggml_tensor * moe_out = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU,
|
||||
hparams.expert_weights_norm, // norm_w (route_norm=True)
|
||||
hparams.expert_weights_scale, // scale_w
|
||||
hparams.expert_weights_scale, // w_scale (route_scale=2.826)
|
||||
(llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||
il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
// shared expert
|
||||
if (hparams.n_expert_shared > 0) {
|
||||
ggml_tensor * ffn_shexp = build_ffn(cur,
|
||||
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(ffn_shexp, "ffn_shexp", il);
|
||||
|
||||
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = moe_out;
|
||||
}
|
||||
} else {
|
||||
// dense layer
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
// dual ffn normalization (post)
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].ffn_post_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_post_norm", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
@@ -57,6 +57,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
int il) const;
|
||||
};
|
||||
|
||||
struct llm_build_afmoe : public llm_graph_context {
|
||||
llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_apertus : public llm_graph_context {
|
||||
llm_build_apertus(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
@@ -729,6 +729,80 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
// AFMOE digit handling: splits digits with leading 1-2 based on total length modulo 3
|
||||
static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & text, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
bpe_offsets.reserve(offsets.size());
|
||||
|
||||
const auto cpts = unicode_cpts_from_utf8(text);
|
||||
|
||||
size_t start = 0;
|
||||
for (auto offset : offsets) {
|
||||
const size_t offset_ini = start;
|
||||
const size_t offset_end = start + offset;
|
||||
assert(offset_end <= cpts.size());
|
||||
start = offset_end;
|
||||
|
||||
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
|
||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
|
||||
};
|
||||
|
||||
size_t _prev_end = offset_ini;
|
||||
auto _add_token = [&] (const size_t end) -> size_t {
|
||||
assert(_prev_end <= end && end <= offset_end);
|
||||
size_t len = end - _prev_end;
|
||||
if (len > 0) {
|
||||
bpe_offsets.push_back(len);
|
||||
}
|
||||
_prev_end = end;
|
||||
return len;
|
||||
};
|
||||
|
||||
for (size_t pos = offset_ini; pos < offset_end; ) {
|
||||
const auto flags = _get_flags(pos);
|
||||
|
||||
// Handle digit sequences with special splitting logic
|
||||
if (flags.is_number) {
|
||||
size_t digit_start = pos;
|
||||
size_t digit_count = 0;
|
||||
|
||||
// Count consecutive digits
|
||||
while (_get_flags(pos).is_number && pos < offset_end) {
|
||||
digit_count++;
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Split based on total length modulo 3
|
||||
size_t remainder = digit_count % 3;
|
||||
size_t current = digit_start;
|
||||
|
||||
// Emit leading 1-2 digits if needed
|
||||
if (remainder > 0) {
|
||||
_add_token(current + remainder);
|
||||
current += remainder;
|
||||
}
|
||||
|
||||
// Emit groups of 3
|
||||
while (current < digit_start + digit_count) {
|
||||
_add_token(current + 3);
|
||||
current += 3;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// For non-digits, just move forward
|
||||
pos++;
|
||||
}
|
||||
|
||||
// Add any remaining content
|
||||
if (_prev_end < offset_end) {
|
||||
_add_token(offset_end);
|
||||
}
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
}
|
||||
|
||||
static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
|
||||
std::vector<size_t> bpe_offsets;
|
||||
|
||||
@@ -742,6 +816,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
|
||||
} else if (regex_expr == "\\p{Han}+") {
|
||||
// K2's first pattern - handle all K2 patterns together
|
||||
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
|
||||
} else if (regex_expr == "\\p{AFMoE_digits}") {
|
||||
// AFMOE digit pattern - use custom implementation for proper splitting
|
||||
bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
|
||||
@@ -5002,17 +5002,19 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
const bool b; // broadcast b matrix (only for use_id)
|
||||
const bool with_bias;
|
||||
const bool with_gate;
|
||||
std::array<int64_t, 2> batch_dims;
|
||||
|
||||
test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
|
||||
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
|
||||
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
|
||||
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,
|
||||
std::array<int64_t, 2> batch_dims = {4, 2})
|
||||
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {
|
||||
if (use_id) {
|
||||
GGML_ASSERT(n_used <= n_mats);
|
||||
}
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
|
||||
return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);
|
||||
}
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
@@ -5038,8 +5040,8 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
if (!use_id) {
|
||||
const int channels = 4;
|
||||
const int samples = 2;
|
||||
const int channels = batch_dims[0];
|
||||
const int samples = batch_dims[1];
|
||||
std::array<int64_t, 4> ne = { k, m, channels, samples };
|
||||
std::array<int64_t, 4> ne0 = { k, n, channels, samples };
|
||||
|
||||
@@ -5062,6 +5064,11 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
|
||||
std::array<int64_t, 4> bias2_ne = { out->ne[0], 1, channels, samples };
|
||||
ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());
|
||||
out = ggml_add(ctx, out, bias2);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
} else {
|
||||
@@ -5089,6 +5096,11 @@ struct test_mul_mat_vec_fusion : public test_case {
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
|
||||
std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };
|
||||
ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());
|
||||
out = ggml_mul(ctx, out, scale);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
}
|
||||
@@ -7546,7 +7558,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_arange());
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
test_cases.emplace_back(new test_cumsum());
|
||||
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 512, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
|
||||
|
||||
test_cases.emplace_back(new test_xielu());
|
||||
|
||||
@@ -7645,6 +7670,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
|
||||
use_id, 16, 8, b, with_bias, with_gate));
|
||||
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
|
||||
use_id, 16, 8, b, with_bias, with_gate, {1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+5
-12
@@ -224,7 +224,6 @@ static void clip_log_callback_default(enum ggml_log_level level, const char * te
|
||||
}
|
||||
|
||||
struct clip_logger_state {
|
||||
ggml_log_level verbosity_thold;
|
||||
ggml_log_callback log_callback;
|
||||
void * log_callback_user_data;
|
||||
};
|
||||
@@ -258,17 +257,11 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
#define LOG_TMPL(level, ...) \
|
||||
do { \
|
||||
if ((level) >= g_logger_state.verbosity_thold) { \
|
||||
clip_log_internal((level), __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
|
||||
#define LOG_INF(...) clip_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) clip_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) clip_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
#define LOG_DBG(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_CNT(...) clip_log_internal(GGML_LOG_LEVEL_CONT, __VA_ARGS__)
|
||||
|
||||
//
|
||||
// cpp wrappers
|
||||
|
||||
+1
-3
@@ -24,8 +24,7 @@
|
||||
#include <array>
|
||||
#include <functional>
|
||||
|
||||
// TODO: allow to pass callback from user code
|
||||
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
|
||||
struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
|
||||
|
||||
enum ffn_op_type {
|
||||
FFN_GELU,
|
||||
@@ -3507,7 +3506,6 @@ struct clip_model_loader {
|
||||
};
|
||||
|
||||
struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
|
||||
g_logger_state.verbosity_thold = ctx_params.verbosity;
|
||||
clip_ctx * ctx_vision = nullptr;
|
||||
clip_ctx * ctx_audio = nullptr;
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ enum clip_flash_attn_type {
|
||||
|
||||
struct clip_context_params {
|
||||
bool use_gpu;
|
||||
enum ggml_log_level verbosity;
|
||||
enum clip_flash_attn_type flash_attn_type;
|
||||
int image_min_tokens;
|
||||
int image_max_tokens;
|
||||
|
||||
@@ -135,7 +135,6 @@ struct mtmd_cli_context {
|
||||
mparams.use_gpu = params.mmproj_use_gpu;
|
||||
mparams.print_timings = true;
|
||||
mparams.n_threads = params.cpuparams.n_threads;
|
||||
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params.flash_attn_type;
|
||||
mparams.image_min_tokens = params.image_min_tokens;
|
||||
mparams.image_max_tokens = params.image_max_tokens;
|
||||
@@ -277,6 +276,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
common_init();
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
if (params.mmproj.path.empty()) {
|
||||
show_additional_info(argc, argv);
|
||||
@@ -285,7 +285,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
mtmd_cli_context ctx(params);
|
||||
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
|
||||
LOG_INF("%s: loading model: %s\n", __func__, params.model.path.c_str());
|
||||
|
||||
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
||||
|
||||
|
||||
@@ -32,8 +32,65 @@
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb/stb_image.h"
|
||||
|
||||
#define LOG_INF(...) fprintf(stdout, __VA_ARGS__)
|
||||
#define LOG_ERR(...) fprintf(stderr, __VA_ARGS__)
|
||||
//
|
||||
// internal logging functions
|
||||
//
|
||||
|
||||
struct mtmd_helper_logger {
|
||||
ggml_log_callback default_callback = [](ggml_log_level level, const char * text, void * user_data) {
|
||||
(void) level;
|
||||
(void) user_data;
|
||||
fputs(text, stderr);
|
||||
fflush(stderr);
|
||||
};
|
||||
|
||||
ggml_log_callback log_callback = default_callback;
|
||||
void * log_callback_user_data;
|
||||
|
||||
void log_v(enum ggml_log_level level, const char * format, va_list args) {
|
||||
if (format == NULL) {
|
||||
return;
|
||||
}
|
||||
va_list args_copy;
|
||||
va_copy(args_copy, args);
|
||||
char buffer[128];
|
||||
int len = vsnprintf(buffer, 128, format, args);
|
||||
if (len < 128) {
|
||||
log_callback(level, buffer, log_callback_user_data);
|
||||
} else {
|
||||
char * buffer2 = (char *) calloc(len + 1, sizeof(char));
|
||||
vsnprintf(buffer2, len + 1, format, args_copy);
|
||||
buffer2[len] = 0;
|
||||
log_callback(level, buffer2, log_callback_user_data);
|
||||
free(buffer2);
|
||||
}
|
||||
va_end(args_copy);
|
||||
}
|
||||
|
||||
void log(enum ggml_log_level level, const char * format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
log_v(level, format, args);
|
||||
va_end(args);
|
||||
}
|
||||
} g_logger;
|
||||
|
||||
#define LOG_INF(...) g_logger.log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) g_logger.log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) g_logger.log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
|
||||
void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
if (log_callback == nullptr) {
|
||||
log_callback = g_logger.default_callback;
|
||||
}
|
||||
g_logger.log_callback = log_callback;
|
||||
g_logger.log_callback_user_data = user_data;
|
||||
mtmd_log_set(log_callback, user_data);
|
||||
}
|
||||
|
||||
//
|
||||
// helper functions
|
||||
//
|
||||
|
||||
size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
|
||||
size_t n_tokens = 0;
|
||||
@@ -325,7 +382,7 @@ int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
|
||||
llama_pos * new_n_past) {
|
||||
size_t n_chunks = mtmd_input_chunks_size(chunks);
|
||||
if (n_chunks == 0) {
|
||||
LOG_ERR("no chunks to eval\n");
|
||||
LOG_WRN("no chunks to eval\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,11 @@ extern "C" {
|
||||
// BREAKING CHANGES are expected.
|
||||
//
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
// Note: this also call mtmd_log_set() internally
|
||||
MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a file
|
||||
// it calls mtmd_helper_bitmap_init_from_buf() internally
|
||||
// returns nullptr on failure
|
||||
|
||||
+5
-2
@@ -105,7 +105,6 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
/* use_gpu */ true,
|
||||
/* print_timings */ true,
|
||||
/* n_threads */ 4,
|
||||
/* verbosity */ GGML_LOG_LEVEL_INFO,
|
||||
/* image_marker */ MTMD_DEFAULT_IMAGE_MARKER,
|
||||
/* media_marker */ mtmd_default_marker(),
|
||||
/* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
|
||||
@@ -175,7 +174,6 @@ struct mtmd_context {
|
||||
|
||||
clip_context_params ctx_clip_params {
|
||||
/* use_gpu */ ctx_params.use_gpu,
|
||||
/* verbosity */ ctx_params.verbosity,
|
||||
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
|
||||
/* image_min_tokens */ ctx_params.image_min_tokens,
|
||||
/* image_max_tokens */ ctx_params.image_max_tokens,
|
||||
@@ -1096,3 +1094,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
void mtmd_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
g_logger_state.log_callback = log_callback ? log_callback : clip_log_callback_default;
|
||||
g_logger_state.log_callback_user_data = user_data;
|
||||
}
|
||||
|
||||
+4
-1
@@ -79,7 +79,6 @@ struct mtmd_context_params {
|
||||
bool use_gpu;
|
||||
bool print_timings;
|
||||
int n_threads;
|
||||
enum ggml_log_level verbosity;
|
||||
const char * image_marker; // deprecated, use media_marker instead
|
||||
const char * media_marker;
|
||||
enum llama_flash_attn_type flash_attn_type;
|
||||
@@ -215,6 +214,10 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
|
||||
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
MTMD_API void mtmd_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
/////////////////////////////////////////
|
||||
|
||||
// test function, to be used in test-mtmd-c-api.c
|
||||
|
||||
Binary file not shown.
+32
-30
@@ -1686,14 +1686,13 @@ struct server_slot {
|
||||
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
|
||||
}
|
||||
|
||||
void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
||||
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
||||
bool res = prompt_cache.load(prompt, tokens, ctx, id);
|
||||
if (!res) {
|
||||
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
|
||||
prompt.tokens.clear();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<common_adapter_lora_info> lora;
|
||||
@@ -2339,7 +2338,6 @@ struct server_context {
|
||||
|
||||
llama_batch batch {};
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
@@ -2454,11 +2452,12 @@ struct server_context {
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mparams.flash_attn_type = params_base.flash_attn_type;
|
||||
mparams.image_min_tokens = params_base.image_min_tokens;
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
@@ -2701,7 +2700,10 @@ struct server_context {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
ret->prompt_save(*prompt_cache);
|
||||
ret->prompt_load(*prompt_cache, task.tokens);
|
||||
|
||||
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
|
||||
clear_slot(*ret);
|
||||
}
|
||||
|
||||
prompt_cache->update();
|
||||
|
||||
@@ -2712,12 +2714,21 @@ struct server_context {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// return true if at least one slot has been purged
|
||||
void clear_slot(server_slot & slot) const {
|
||||
GGML_ASSERT(!slot.is_processing());
|
||||
|
||||
SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
// return true if at least one slot has been cleared
|
||||
// TODO: improve logic
|
||||
// - smarter decision which slot to purge (LRU or longest prompt?)
|
||||
// - smarter decision which slot to clear (LRU or longest prompt?)
|
||||
// - move slot to level 2 cache instead of removing?
|
||||
// - instead of purging, try to store and resume later?
|
||||
bool try_purge_idle_slots() {
|
||||
bool try_clear_idle_slots() {
|
||||
bool res = false;
|
||||
|
||||
if (!params_base.kv_unified) {
|
||||
@@ -2732,12 +2743,11 @@ struct server_context {
|
||||
if (slot.prompt.n_tokens() > 0) {
|
||||
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
clear_slot(slot);
|
||||
|
||||
res = true;
|
||||
|
||||
// purge slots one by one
|
||||
// clear slots one by one
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -2847,14 +2857,6 @@ struct server_context {
|
||||
return true;
|
||||
}
|
||||
|
||||
void kv_cache_clear() {
|
||||
SRV_DBG("%s", "clearing KV cache\n");
|
||||
|
||||
// clear the entire KV cache
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
clean_kv_cache = false;
|
||||
}
|
||||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = result.text_to_send;
|
||||
@@ -3442,8 +3444,8 @@ struct server_context {
|
||||
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->prompt.tokens.size();
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
|
||||
slot->prompt.tokens.clear();
|
||||
|
||||
clear_slot(*slot);
|
||||
|
||||
auto res = std::make_unique<server_task_result_slot_erase>();
|
||||
res->id = task.id;
|
||||
@@ -3476,9 +3478,6 @@ struct server_context {
|
||||
|
||||
if (all_idle) {
|
||||
SRV_INF("%s", "all slots are idle\n");
|
||||
if (clean_kv_cache) {
|
||||
kv_cache_clear();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -3872,12 +3871,11 @@ struct server_context {
|
||||
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
|
||||
clear_slot(slot);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_prompt_tokens_cache = 0;
|
||||
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
// check if we should process the image
|
||||
@@ -4107,6 +4105,10 @@ struct server_context {
|
||||
if (slot.is_processing()) {
|
||||
send_error(slot, err);
|
||||
slot.release();
|
||||
|
||||
// note: it's complicated to keep track of how much of the current batch has been
|
||||
// processed before the error occurred, so we simply clear the entire context
|
||||
clear_slot(slot);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4115,7 +4117,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
if (!try_purge_idle_slots()) {
|
||||
if (!try_clear_idle_slots()) {
|
||||
n_batch /= 2;
|
||||
}
|
||||
|
||||
|
||||
+1
-7
@@ -72,12 +72,6 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (isOpen) {
|
||||
updateMenuPosition();
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSelect(value: string | undefined) {
|
||||
if (!value) return;
|
||||
|
||||
@@ -259,7 +253,7 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<svelte:window onresize={handleResize} onscroll={handleScroll} />
|
||||
<svelte:window onresize={handleResize} />
|
||||
|
||||
<svelte:document onpointerdown={handlePointerDown} onkeydown={handleKeydown} />
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { getDeletionInfo } from '$lib/stores/chat.svelte';
|
||||
import { copyToClipboard } from '$lib/utils/copy';
|
||||
import { isIMEComposing } from '$lib/utils/is-ime-composing';
|
||||
import type { ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
|
||||
import ChatMessageUser from './ChatMessageUser.svelte';
|
||||
|
||||
@@ -54,6 +55,29 @@
|
||||
return null;
|
||||
});
|
||||
|
||||
let toolCallContent = $derived.by((): ApiChatCompletionToolCall[] | string | null => {
|
||||
if (message.role === 'assistant') {
|
||||
const trimmedToolCalls = message.toolCalls?.trim();
|
||||
|
||||
if (!trimmedToolCalls) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(trimmedToolCalls);
|
||||
|
||||
if (Array.isArray(parsed)) {
|
||||
return parsed as ApiChatCompletionToolCall[];
|
||||
}
|
||||
} catch {
|
||||
// Harmony-only path: fall back to the raw string so issues surface visibly.
|
||||
}
|
||||
|
||||
return trimmedToolCalls;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
function handleCancelEdit() {
|
||||
isEditing = false;
|
||||
editedContent = message.content;
|
||||
@@ -171,5 +195,6 @@
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
{thinkingContent}
|
||||
{toolCallContent}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
+117
-2
@@ -11,7 +11,8 @@
|
||||
Gauge,
|
||||
Clock,
|
||||
WholeWord,
|
||||
ChartNoAxesColumn
|
||||
ChartNoAxesColumn,
|
||||
Wrench
|
||||
} from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
@@ -21,6 +22,7 @@
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { modelName as serverModelName } from '$lib/stores/server.svelte';
|
||||
import { copyToClipboard } from '$lib/utils/copy';
|
||||
import type { ApiChatCompletionToolCall } from '$lib/types/api';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -51,6 +53,7 @@
|
||||
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||
textareaElement?: HTMLTextAreaElement;
|
||||
thinkingContent: string | null;
|
||||
toolCallContent: ApiChatCompletionToolCall[] | string | null;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -76,9 +79,15 @@
|
||||
shouldBranchAfterEdit = false,
|
||||
siblingInfo = null,
|
||||
textareaElement = $bindable(),
|
||||
thinkingContent
|
||||
thinkingContent,
|
||||
toolCallContent = null
|
||||
}: Props = $props();
|
||||
|
||||
const toolCalls = $derived(
|
||||
Array.isArray(toolCallContent) ? (toolCallContent as ApiChatCompletionToolCall[]) : null
|
||||
);
|
||||
const fallbackToolCalls = $derived(typeof toolCallContent === 'string' ? toolCallContent : null);
|
||||
|
||||
const processingState = useProcessingState();
|
||||
let currentConfig = $derived(config());
|
||||
let serverModel = $derived(serverModelName());
|
||||
@@ -97,6 +106,58 @@
|
||||
|
||||
void copyToClipboard(model ?? '');
|
||||
}
|
||||
|
||||
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
|
||||
const callNumber = index + 1;
|
||||
const functionName = toolCall.function?.name?.trim();
|
||||
const label = functionName || `Call #${callNumber}`;
|
||||
|
||||
const payload: Record<string, unknown> = {};
|
||||
|
||||
const id = toolCall.id?.trim();
|
||||
if (id) {
|
||||
payload.id = id;
|
||||
}
|
||||
|
||||
const type = toolCall.type?.trim();
|
||||
if (type) {
|
||||
payload.type = type;
|
||||
}
|
||||
|
||||
if (toolCall.function) {
|
||||
const fnPayload: Record<string, unknown> = {};
|
||||
|
||||
const name = toolCall.function.name?.trim();
|
||||
if (name) {
|
||||
fnPayload.name = name;
|
||||
}
|
||||
|
||||
const rawArguments = toolCall.function.arguments?.trim();
|
||||
if (rawArguments) {
|
||||
try {
|
||||
fnPayload.arguments = JSON.parse(rawArguments);
|
||||
} catch {
|
||||
fnPayload.arguments = rawArguments;
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(fnPayload).length > 0) {
|
||||
payload.function = fnPayload;
|
||||
}
|
||||
}
|
||||
|
||||
const formattedPayload = JSON.stringify(payload, null, 2);
|
||||
|
||||
return {
|
||||
label,
|
||||
tooltip: formattedPayload,
|
||||
copyValue: formattedPayload
|
||||
};
|
||||
}
|
||||
|
||||
function handleCopyToolCall(payload: string) {
|
||||
void copyToClipboard(payload, 'Tool call copied to clipboard');
|
||||
}
|
||||
</script>
|
||||
|
||||
<div
|
||||
@@ -189,6 +250,47 @@
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
{#if config().showToolCalls}
|
||||
{#if (toolCalls && toolCalls.length > 0) || fallbackToolCalls}
|
||||
<span class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
|
||||
<span class="inline-flex items-center gap-1">
|
||||
<Wrench class="h-3.5 w-3.5" />
|
||||
|
||||
<span>Tool calls:</span>
|
||||
</span>
|
||||
|
||||
{#if toolCalls && toolCalls.length > 0}
|
||||
{#each toolCalls as toolCall, index (toolCall.id ?? `${index}`)}
|
||||
{@const badge = formatToolCallBadge(toolCall, index)}
|
||||
<button
|
||||
type="button"
|
||||
class="tool-call-badge inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
title={badge.tooltip}
|
||||
aria-label={`Copy tool call ${badge.label}`}
|
||||
onclick={() => handleCopyToolCall(badge.copyValue)}
|
||||
>
|
||||
{badge.label}
|
||||
|
||||
<Copy class="ml-1 h-3 w-3" />
|
||||
</button>
|
||||
{/each}
|
||||
{:else if fallbackToolCalls}
|
||||
<button
|
||||
type="button"
|
||||
class="tool-call-badge tool-call-badge--fallback inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
|
||||
title={fallbackToolCalls}
|
||||
aria-label="Copy tool call payload"
|
||||
onclick={() => handleCopyToolCall(fallbackToolCalls)}
|
||||
>
|
||||
{fallbackToolCalls}
|
||||
|
||||
<Copy class="ml-1 h-3 w-3" />
|
||||
</button>
|
||||
{/if}
|
||||
</span>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
|
||||
{@const tokensPerSecond = (message.timings.predicted_n / message.timings.predicted_ms) * 1000}
|
||||
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
|
||||
@@ -287,4 +389,17 @@
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tool-call-badge {
|
||||
max-width: 12rem;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.tool-call-badge--fallback {
|
||||
max-width: 20rem;
|
||||
white-space: normal;
|
||||
word-break: break-word;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -76,10 +76,10 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="chat-processing-info-container" class:visible={showSlotsInfo}>
|
||||
<div class="chat-processing-info-container pointer-events-none" class:visible={showSlotsInfo}>
|
||||
<div class="chat-processing-info-content">
|
||||
{#each processingDetails as detail (detail)}
|
||||
<span class="chat-processing-info-detail">{detail}</span>
|
||||
<span class="chat-processing-info-detail pointer-events-auto">{detail}</span>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
@@ -92,7 +92,6 @@
|
||||
padding: 1.5rem 1rem;
|
||||
opacity: 0;
|
||||
transform: translateY(50%);
|
||||
pointer-events: none;
|
||||
transition:
|
||||
opacity 300ms ease-out,
|
||||
transform 300ms ease-out;
|
||||
@@ -100,7 +99,6 @@
|
||||
|
||||
.chat-processing-info-container.visible {
|
||||
opacity: 1;
|
||||
pointer-events: auto;
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
|
||||
@@ -226,6 +226,11 @@
|
||||
label: 'Enable model selector',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showToolCalls',
|
||||
label: 'Show tool call labels',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'disableReasoningFormat',
|
||||
label: 'Show raw LLM output',
|
||||
|
||||
@@ -6,6 +6,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||
theme: 'system',
|
||||
showTokensPerSecond: false,
|
||||
showThoughtInProgress: false,
|
||||
showToolCalls: false,
|
||||
disableReasoningFormat: false,
|
||||
keepStatsVisible: false,
|
||||
showMessageStats: true,
|
||||
@@ -80,6 +81,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.',
|
||||
showTokensPerSecond: 'Display generation speed in tokens per second during streaming.',
|
||||
showThoughtInProgress: 'Expand thought process by default when generating messages.',
|
||||
showToolCalls:
|
||||
'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.',
|
||||
disableReasoningFormat:
|
||||
'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
|
||||
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
|
||||
|
||||
@@ -1,6 +1,25 @@
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { selectedModelName } from '$lib/stores/models.svelte';
|
||||
import { slotsService } from './slots';
|
||||
import type {
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionResponse,
|
||||
ApiChatCompletionStreamChunk,
|
||||
ApiChatCompletionToolCall,
|
||||
ApiChatCompletionToolCallDelta,
|
||||
ApiChatMessageData
|
||||
} from '$lib/types/api';
|
||||
import type {
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraLegacyContext,
|
||||
DatabaseMessageExtraPdfFile,
|
||||
DatabaseMessageExtraTextFile
|
||||
} from '$lib/types/database';
|
||||
import type { ChatMessagePromptProgress, ChatMessageTimings } from '$lib/types/chat';
|
||||
import type { SettingsChatServiceOptions } from '$lib/types/settings';
|
||||
/**
|
||||
* ChatService - Low-level API communication layer for llama.cpp server interactions
|
||||
*
|
||||
@@ -53,6 +72,7 @@ export class ChatService {
|
||||
onComplete,
|
||||
onError,
|
||||
onReasoningChunk,
|
||||
onToolCallChunk,
|
||||
onModel,
|
||||
onFirstValidChunk,
|
||||
// Generation parameters
|
||||
@@ -201,6 +221,7 @@ export class ChatService {
|
||||
onComplete,
|
||||
onError,
|
||||
onReasoningChunk,
|
||||
onToolCallChunk,
|
||||
onModel,
|
||||
onFirstValidChunk,
|
||||
conversationId,
|
||||
@@ -208,7 +229,13 @@ export class ChatService {
|
||||
);
|
||||
return;
|
||||
} else {
|
||||
return this.handleNonStreamResponse(response, onComplete, onError, onModel);
|
||||
return this.handleNonStreamResponse(
|
||||
response,
|
||||
onComplete,
|
||||
onError,
|
||||
onToolCallChunk,
|
||||
onModel
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
@@ -264,10 +291,12 @@ export class ChatService {
|
||||
onComplete?: (
|
||||
response: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings
|
||||
timings?: ChatMessageTimings,
|
||||
toolCalls?: string
|
||||
) => void,
|
||||
onError?: (error: Error) => void,
|
||||
onReasoningChunk?: (chunk: string) => void,
|
||||
onToolCallChunk?: (chunk: string) => void,
|
||||
onModel?: (model: string) => void,
|
||||
onFirstValidChunk?: () => void,
|
||||
conversationId?: string,
|
||||
@@ -282,11 +311,53 @@ export class ChatService {
|
||||
const decoder = new TextDecoder();
|
||||
let aggregatedContent = '';
|
||||
let fullReasoningContent = '';
|
||||
let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
let hasReceivedData = false;
|
||||
let lastTimings: ChatMessageTimings | undefined;
|
||||
let streamFinished = false;
|
||||
let modelEmitted = false;
|
||||
let firstValidChunkEmitted = false;
|
||||
let toolCallIndexOffset = 0;
|
||||
let hasOpenToolCallBatch = false;
|
||||
|
||||
const finalizeOpenToolCallBatch = () => {
|
||||
if (!hasOpenToolCallBatch) {
|
||||
return;
|
||||
}
|
||||
|
||||
toolCallIndexOffset = aggregatedToolCalls.length;
|
||||
hasOpenToolCallBatch = false;
|
||||
};
|
||||
|
||||
const processToolCallDelta = (toolCalls?: ApiChatCompletionToolCallDelta[]) => {
|
||||
if (!toolCalls || toolCalls.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
aggregatedToolCalls = this.mergeToolCallDeltas(
|
||||
aggregatedToolCalls,
|
||||
toolCalls,
|
||||
toolCallIndexOffset
|
||||
);
|
||||
|
||||
if (aggregatedToolCalls.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
hasOpenToolCallBatch = true;
|
||||
|
||||
const serializedToolCalls = JSON.stringify(aggregatedToolCalls);
|
||||
|
||||
if (!serializedToolCalls) {
|
||||
return;
|
||||
}
|
||||
|
||||
hasReceivedData = true;
|
||||
|
||||
if (!abortSignal?.aborted) {
|
||||
onToolCallChunk?.(serializedToolCalls);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
let chunk = '';
|
||||
@@ -325,6 +396,7 @@ export class ChatService {
|
||||
|
||||
const content = parsed.choices[0]?.delta?.content;
|
||||
const reasoningContent = parsed.choices[0]?.delta?.reasoning_content;
|
||||
const toolCalls = parsed.choices[0]?.delta?.tool_calls;
|
||||
const timings = parsed.timings;
|
||||
const promptProgress = parsed.prompt_progress;
|
||||
|
||||
@@ -342,6 +414,7 @@ export class ChatService {
|
||||
}
|
||||
|
||||
if (content) {
|
||||
finalizeOpenToolCallBatch();
|
||||
hasReceivedData = true;
|
||||
aggregatedContent += content;
|
||||
if (!abortSignal?.aborted) {
|
||||
@@ -350,12 +423,15 @@ export class ChatService {
|
||||
}
|
||||
|
||||
if (reasoningContent) {
|
||||
finalizeOpenToolCallBatch();
|
||||
hasReceivedData = true;
|
||||
fullReasoningContent += reasoningContent;
|
||||
if (!abortSignal?.aborted) {
|
||||
onReasoningChunk?.(reasoningContent);
|
||||
}
|
||||
}
|
||||
|
||||
processToolCallDelta(toolCalls);
|
||||
} catch (e) {
|
||||
console.error('Error parsing JSON chunk:', e);
|
||||
}
|
||||
@@ -368,12 +444,26 @@ export class ChatService {
|
||||
if (abortSignal?.aborted) return;
|
||||
|
||||
if (streamFinished) {
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
finalizeOpenToolCallBatch();
|
||||
|
||||
if (
|
||||
!hasReceivedData &&
|
||||
aggregatedContent.length === 0 &&
|
||||
aggregatedToolCalls.length === 0
|
||||
) {
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
|
||||
const finalToolCalls =
|
||||
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;
|
||||
|
||||
onComplete?.(
|
||||
aggregatedContent,
|
||||
fullReasoningContent || undefined,
|
||||
lastTimings,
|
||||
finalToolCalls
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
const err = error instanceof Error ? error : new Error('Stream error');
|
||||
@@ -386,6 +476,54 @@ export class ChatService {
|
||||
}
|
||||
}
|
||||
|
||||
private mergeToolCallDeltas(
|
||||
existing: ApiChatCompletionToolCall[],
|
||||
deltas: ApiChatCompletionToolCallDelta[],
|
||||
indexOffset = 0
|
||||
): ApiChatCompletionToolCall[] {
|
||||
const result = existing.map((call) => ({
|
||||
...call,
|
||||
function: call.function ? { ...call.function } : undefined
|
||||
}));
|
||||
|
||||
for (const delta of deltas) {
|
||||
const index =
|
||||
typeof delta.index === 'number' && delta.index >= 0
|
||||
? delta.index + indexOffset
|
||||
: result.length;
|
||||
|
||||
while (result.length <= index) {
|
||||
result.push({ function: undefined });
|
||||
}
|
||||
|
||||
const target = result[index]!;
|
||||
|
||||
if (delta.id) {
|
||||
target.id = delta.id;
|
||||
}
|
||||
|
||||
if (delta.type) {
|
||||
target.type = delta.type;
|
||||
}
|
||||
|
||||
if (delta.function) {
|
||||
const fn = target.function ? { ...target.function } : {};
|
||||
|
||||
if (delta.function.name) {
|
||||
fn.name = delta.function.name;
|
||||
}
|
||||
|
||||
if (delta.function.arguments) {
|
||||
fn.arguments = (fn.arguments ?? '') + delta.function.arguments;
|
||||
}
|
||||
|
||||
target.function = fn;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles non-streaming response from the chat completion API.
|
||||
* Parses the JSON response and extracts the generated content.
|
||||
@@ -401,9 +539,11 @@ export class ChatService {
|
||||
onComplete?: (
|
||||
response: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings
|
||||
timings?: ChatMessageTimings,
|
||||
toolCalls?: string
|
||||
) => void,
|
||||
onError?: (error: Error) => void,
|
||||
onToolCallChunk?: (chunk: string) => void,
|
||||
onModel?: (model: string) => void
|
||||
): Promise<string> {
|
||||
try {
|
||||
@@ -423,17 +563,31 @@ export class ChatService {
|
||||
|
||||
const content = data.choices[0]?.message?.content || '';
|
||||
const reasoningContent = data.choices[0]?.message?.reasoning_content;
|
||||
const toolCalls = data.choices[0]?.message?.tool_calls;
|
||||
|
||||
if (reasoningContent) {
|
||||
console.log('Full reasoning content:', reasoningContent);
|
||||
}
|
||||
|
||||
if (!content.trim()) {
|
||||
let serializedToolCalls: string | undefined;
|
||||
|
||||
if (toolCalls && toolCalls.length > 0) {
|
||||
const mergedToolCalls = this.mergeToolCallDeltas([], toolCalls);
|
||||
|
||||
if (mergedToolCalls.length > 0) {
|
||||
serializedToolCalls = JSON.stringify(mergedToolCalls);
|
||||
if (serializedToolCalls) {
|
||||
onToolCallChunk?.(serializedToolCalls);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!content.trim() && !serializedToolCalls) {
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(content, reasoningContent);
|
||||
onComplete?.(content, reasoningContent, undefined, serializedToolCalls);
|
||||
|
||||
return content;
|
||||
} catch (error) {
|
||||
|
||||
@@ -205,6 +205,7 @@ class ChatStore {
|
||||
type,
|
||||
timestamp: Date.now(),
|
||||
thinking: '',
|
||||
toolCalls: '',
|
||||
children: [],
|
||||
extra: extras
|
||||
},
|
||||
@@ -360,6 +361,7 @@ class ChatStore {
|
||||
): Promise<void> {
|
||||
let streamedContent = '';
|
||||
let streamedReasoningContent = '';
|
||||
let streamedToolCallContent = '';
|
||||
|
||||
let resolvedModel: string | null = null;
|
||||
let modelPersisted = false;
|
||||
@@ -468,6 +470,20 @@ class ChatStore {
|
||||
this.updateMessageAtIndex(messageIndex, { thinking: streamedReasoningContent });
|
||||
},
|
||||
|
||||
onToolCallChunk: (toolCallChunk: string) => {
|
||||
const chunk = toolCallChunk.trim();
|
||||
|
||||
if (!chunk) {
|
||||
return;
|
||||
}
|
||||
|
||||
streamedToolCallContent = chunk;
|
||||
|
||||
const messageIndex = this.findMessageIndex(assistantMessage.id);
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, { toolCalls: streamedToolCallContent });
|
||||
},
|
||||
|
||||
onModel: (modelName: string) => {
|
||||
recordModel(modelName);
|
||||
},
|
||||
@@ -475,18 +491,21 @@ class ChatStore {
|
||||
onComplete: async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings
|
||||
timings?: ChatMessageTimings,
|
||||
toolCallContent?: string
|
||||
) => {
|
||||
slotsService.stopStreaming();
|
||||
|
||||
const updateData: {
|
||||
content: string;
|
||||
thinking: string;
|
||||
toolCalls: string;
|
||||
timings?: ChatMessageTimings;
|
||||
model?: string;
|
||||
} = {
|
||||
content: finalContent || streamedContent,
|
||||
thinking: reasoningContent || streamedReasoningContent,
|
||||
toolCalls: toolCallContent || streamedToolCallContent,
|
||||
timings: timings
|
||||
};
|
||||
|
||||
@@ -499,7 +518,11 @@ class ChatStore {
|
||||
|
||||
const messageIndex = this.findMessageIndex(assistantMessage.id);
|
||||
|
||||
const localUpdateData: { timings?: ChatMessageTimings; model?: string } = {
|
||||
const localUpdateData: {
|
||||
timings?: ChatMessageTimings;
|
||||
model?: string;
|
||||
toolCalls?: string;
|
||||
} = {
|
||||
timings: timings
|
||||
};
|
||||
|
||||
@@ -507,6 +530,10 @@ class ChatStore {
|
||||
localUpdateData.model = updateData.model;
|
||||
}
|
||||
|
||||
if (updateData.toolCalls !== undefined) {
|
||||
localUpdateData.toolCalls = updateData.toolCalls;
|
||||
}
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, localUpdateData);
|
||||
|
||||
await DatabaseStore.updateCurrentNode(assistantMessage.convId, assistantMessage.id);
|
||||
@@ -620,6 +647,7 @@ class ChatStore {
|
||||
content: '',
|
||||
timestamp: Date.now(),
|
||||
thinking: '',
|
||||
toolCalls: '',
|
||||
children: [],
|
||||
model: null
|
||||
},
|
||||
@@ -1443,6 +1471,7 @@ class ChatStore {
|
||||
role: messageToEdit.role,
|
||||
content: newContent,
|
||||
thinking: messageToEdit.thinking || '',
|
||||
toolCalls: messageToEdit.toolCalls || '',
|
||||
children: [],
|
||||
model: messageToEdit.model // Preserve original model info when branching
|
||||
},
|
||||
@@ -1518,6 +1547,7 @@ class ChatStore {
|
||||
role: messageToEdit.role,
|
||||
content: newContent,
|
||||
thinking: messageToEdit.thinking || '',
|
||||
toolCalls: messageToEdit.toolCalls || '',
|
||||
children: [],
|
||||
extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined,
|
||||
model: messageToEdit.model // Preserve original model info when branching
|
||||
@@ -1589,6 +1619,7 @@ class ChatStore {
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
thinking: '',
|
||||
toolCalls: '',
|
||||
children: [],
|
||||
model: null
|
||||
},
|
||||
@@ -1647,6 +1678,7 @@ class ChatStore {
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
thinking: '',
|
||||
toolCalls: '',
|
||||
children: [],
|
||||
model: null
|
||||
},
|
||||
|
||||
@@ -114,6 +114,7 @@ export class DatabaseStore {
|
||||
...message,
|
||||
id: uuid(),
|
||||
parent: parentId,
|
||||
toolCalls: message.toolCalls ?? '',
|
||||
children: []
|
||||
};
|
||||
|
||||
@@ -154,6 +155,7 @@ export class DatabaseStore {
|
||||
content: '',
|
||||
parent: null,
|
||||
thinking: '',
|
||||
toolCalls: '',
|
||||
children: []
|
||||
};
|
||||
|
||||
|
||||
+19
@@ -183,6 +183,23 @@ export interface ApiChatCompletionRequest {
|
||||
samplers?: string[];
|
||||
// Custom parameters (JSON string)
|
||||
custom?: Record<string, unknown>;
|
||||
timings_per_token?: boolean;
|
||||
}
|
||||
|
||||
export interface ApiChatCompletionToolCallFunctionDelta {
|
||||
name?: string;
|
||||
arguments?: string;
|
||||
}
|
||||
|
||||
export interface ApiChatCompletionToolCallDelta {
|
||||
index?: number;
|
||||
id?: string;
|
||||
type?: string;
|
||||
function?: ApiChatCompletionToolCallFunctionDelta;
|
||||
}
|
||||
|
||||
export interface ApiChatCompletionToolCall extends ApiChatCompletionToolCallDelta {
|
||||
function?: ApiChatCompletionToolCallFunctionDelta & { arguments?: string };
|
||||
}
|
||||
|
||||
export interface ApiChatCompletionStreamChunk {
|
||||
@@ -195,6 +212,7 @@ export interface ApiChatCompletionStreamChunk {
|
||||
content?: string;
|
||||
reasoning_content?: string;
|
||||
model?: string;
|
||||
tool_calls?: ApiChatCompletionToolCallDelta[];
|
||||
};
|
||||
}>;
|
||||
timings?: {
|
||||
@@ -216,6 +234,7 @@ export interface ApiChatCompletionResponse {
|
||||
content: string;
|
||||
reasoning_content?: string;
|
||||
model?: string;
|
||||
tool_calls?: ApiChatCompletionToolCallDelta[];
|
||||
};
|
||||
}>;
|
||||
}
|
||||
|
||||
@@ -60,6 +60,7 @@ export interface DatabaseMessage {
|
||||
content: string;
|
||||
parent: string;
|
||||
thinking: string;
|
||||
toolCalls?: string;
|
||||
children: string[];
|
||||
extra?: DatabaseMessageExtra[];
|
||||
timings?: ChatMessageTimings;
|
||||
|
||||
+8
-1
@@ -38,12 +38,19 @@ export interface SettingsChatServiceOptions {
|
||||
samplers?: string | string[];
|
||||
// Custom parameters
|
||||
custom?: string;
|
||||
timings_per_token?: boolean;
|
||||
// Callbacks
|
||||
onChunk?: (chunk: string) => void;
|
||||
onReasoningChunk?: (chunk: string) => void;
|
||||
onToolCallChunk?: (chunk: string) => void;
|
||||
onModel?: (model: string) => void;
|
||||
onFirstValidChunk?: () => void;
|
||||
onComplete?: (response: string, reasoningContent?: string, timings?: ChatMessageTimings) => void;
|
||||
onComplete?: (
|
||||
response: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings,
|
||||
toolCalls?: string
|
||||
) => void;
|
||||
onError?: (error: Error) => void;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user