Compare commits

...

22 Commits

Author SHA1 Message Date
Georgi Gerganov 85cc1ae998 context : print graph stats for memory-less contexts (#15586)
ggml-ci
2025-08-26 12:47:00 +03:00
Georgi Gerganov 1d8d83deaa metal : improve MUL_MAT_ID (#15541)
* metal : mul_mm_id remove hdst

* metal : remove mul_mm_id hsrc1

* metal : mul_mm_id simplify + add test

* metal : opt mul_mm_id map0

* metal : optimize mul_mm_id id gathering

* metal : mul/div opt

* metal : optimize mul_mm_id_map0

ggml-ci
2025-08-26 12:46:15 +03:00
tc-mb c4e9239064 model : support MiniCPM-V 4.5 (#15575) 2025-08-26 10:05:55 +02:00
Sigbjørn Skjæret 39842a7f73 gguf-py : remove erroneous FFN_GATE entry (#15583) 2025-08-26 09:08:08 +02:00
Sigbjørn Skjæret 0fd90db585 metal : remove contiguous assertion for src0 in IM2COL (#15577)
* remove contiguous assertion for src0 in IM2COL

* add contiguous check in supports_op
2025-08-26 09:51:43 +03:00
Yoshi_likes_e4 4c37636b3e Add a warning for special devices (#15563)
* Add warning

* Print the devices names

* Add newlines

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Fix vector names

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-08-26 08:15:33 +02:00
Jeff Bolz 34bdbbd7c2 vulkan: Remove splitting for mul_mat_id (#15568)
row_ids only needs to hold the BN rows for the current tile.
2025-08-26 06:42:44 +02:00
Qeeweew 74f52f77f2 CUDA: Accelerate MXFP4 table lookup using __byte_perm (#15451)
* CUDA: optimize get_int_from_table_16

* CUDA: use v_perm_b32 to replace byte_perm on AMD GPUs

* revise documentation

---------

Co-authored-by: xix <xiapc@outlook.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-08-25 23:21:22 +02:00
lhez f7207b0415 opencl: fix support ops condition for rms_norm (#15560) 2025-08-25 14:18:09 -07:00
Ruben Ortlam 4d917cd4f6 vulkan: fix min subgroup 16 condition for mmid subgroup optimization (#15565) 2025-08-25 17:56:59 +02:00
Jeff Bolz 886b97a5d6 tests: Generate unique input values for count_equal (#15487)
This avoids backend-dependent behavior for argmax that leads to intermittent failures.
2025-08-25 10:47:16 -05:00
Ihar Hrachyshka 111f8d06f0 metal: fix regression when no metal devices are present (#15531) 2025-08-25 18:27:34 +03:00
Johannes Gäßler 5eff6ec9b1 CUDA: MoE helper in device code, better tile sizes (#15525)
* CUDA: MoE helper in device code, better tile sizes

* reduce superfluous CUDA blocks
2025-08-25 17:23:40 +02:00
Daniel Bevenius dfd9b5f6c7 model-conversion : set pooling type to none in logits.cpp (#15564)
This commit explicitly sets the pooling type to 'none' in the logits.cpp
to support models that have a pooling type specified.

The motivation for this is that some models may have a pooling type set
in the model file (.gguf file) and for this specific case where we only
want to extract logits, we need to ensure that no pooling is used to
so that we are comparing raw logits and not pooled embeddings.
2025-08-25 15:00:43 +02:00
Daniel Bevenius 5a6bc6b1a6 model-conversion : add model card template for embeddings [no ci] (#15557)
* model-conversion: add model card template for embeddings [no ci]

This commit adds a separate model card template (model repository
README.md template) for embedding models.

The motivation for this is that there server command for the embedding
model is a little different and some addition information can be useful
in the model card for embedding models which might not be directly
relevant for causal models.

* squash! model-conversion: add model card template for embeddings [no ci]

Fix pyright lint error.

* remove --pooling override and clarify embd_normalize usage
2025-08-25 14:25:25 +02:00
Georgi Gerganov 6b64f74b55 batched-bench : fix unified KV cache handling + pp timing (#15562)
* batched-bench : fix unified KV cache handling + pp timing

* cont : run dummy token only with split KV cache
2025-08-25 13:56:43 +03:00
Weizhao Ouyang 0d5a470223 convert : update Ernie 4.5 dense architecture name (#15555)
Signed-off-by: Weizhao Ouyang <o451686892@gmail.com>
2025-08-25 11:15:06 +02:00
Georgi Gerganov b0ba31f525 metal : add FA kernels for HS=40 (#15559)
ggml-ci
2025-08-25 10:14:48 +03:00
RunningLeon 7da9fed0d6 convert : support interns1-mini (#15412)
* support interns1-mini

* fix comment

* update
2025-08-25 08:32:16 +02:00
Chenguang Li c247d06f38 CANN: ROPE cache sin/cos repeat (#15501)
Signed-off-by: noemotiovon <757486878@qq.com>
2025-08-25 10:32:21 +08:00
Ruben Ortlam 043fb27d38 vulkan: apply MUL_MAT_ID subgroup optimization to non-coopmat devices (#15524)
* vulkan: use subgroup function for mul_mat_id shader even without coopmat

* vulkan: fix compile warnings

* vulkan: properly check for subgroup size control and require full subgroups for subgroup mul_mat_id

* vulkan: disable subgroup mul_mat_id on devices with subgroups < 16
2025-08-24 19:36:36 +02:00
Georgi Gerganov b730706a49 kv-cache : support layer reuse (#15504)
* kv-cache : support layer reuse

ggml-ci

* cont : update comments [no ci]
2025-08-24 13:07:07 +03:00
44 changed files with 1493 additions and 895 deletions
+66 -69
View File
@@ -1216,6 +1216,55 @@ class TextModel(ModelBase):
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
self.gguf_writer.add_pooling_type(pooling_type)
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
@@ -2932,7 +2981,8 @@ class Qwen2Model(TextModel):
if "language_model." in name:
name = name.replace("language_model.", "") # for InternVL
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
or name.startswith("vision_model") or name.startswith("audio_tower"):
or name.startswith("vision_model") or name.startswith("audio_tower") \
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip vision and audio tensors
return []
yield from super().modify_tensors(data_torch, name, bid)
@@ -3109,7 +3159,7 @@ class LLaDAModel(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Ernie4_5_ForCausalLM")
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5
@@ -3604,6 +3654,19 @@ class Qwen2MoeModel(TextModel):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
def set_vocab(self):
# deal with intern-s1-mini
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
self._set_vocab_interns1()
return
super().set_vocab()
@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):
@@ -3620,73 +3683,7 @@ class Qwen3MoeModel(Qwen2MoeModel):
self._set_vocab_interns1()
return
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
additional_special_tokens = []
if special_tokens_map_file.is_file():
with open(special_tokens_map_file, encoding = 'utf-8') as f:
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
if tokenizer_cfg_file.is_file():
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
for token in additional_special_tokens:
if token in token2ids_map:
special_vocab._set_special_token(token, token2ids_map[token])
special_vocab._set_special_token('eos', 151645)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)
super().set_vocab()
@ModelBase.register("GPT2LMHeadModel")
+1 -1
View File
@@ -6,7 +6,7 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
### Build llama.cpp
Readme modification time: 20250206
Readme modification time: 20250731
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
+47
View File
@@ -0,0 +1,47 @@
## MiniCPM-V 4.5
### Prepare models and code
Download [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch model from huggingface to "MiniCPM-V-4_5" folder.
### Build llama.cpp
Readme modification time: 20250826
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp:
```bash
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
```
Build llama.cpp using `CMake`:
```bash
cmake -B build
cmake --build build --config Release
```
### Usage of MiniCPM-V 4
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-4_5-gguf) by us)
```bash
python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4_5
python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4_5 --minicpmv-projector ../MiniCPM-V-4_5/minicpmv.projector --output-dir ../MiniCPM-V-4_5/ --minicpmv_version 6
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_5/model
# quantize int4 version
./build/bin/llama-quantize ../MiniCPM-V-4_5/model/ggml-model-f16.gguf ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
```
Inference on Linux or Mac
```bash
# run in single-turn mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run in conversation mode
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf
```
+9
View File
@@ -144,6 +144,15 @@ perplexity-run:
hf-create-model:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}"
hf-create-model-dry-run:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d
hf-create-model-embedding:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e
hf-create-model-embedding-dry-run:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d
hf-create-model-private:
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p
+9 -1
View File
@@ -285,13 +285,21 @@ For the following targets a `HF_TOKEN` environment variable is required.
This will create a new model repsository on Hugging Face with the specified
model name.
```console
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev"
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
Repository ID: danbev/TestModel-GGUF
Repository created: https://huggingface.co/danbev/TestModel-GGUF
```
Note that we append a `-GGUF` suffix to the model name to ensure a consistent
naming convention for GGUF models.
An embedding model can be created using the following command:
```console
(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
```
The only difference is that the model card for an embedding model will be different
with regards to the llama-server command and also how to access/call the embedding
endpoint.
### Upload a GGUF model to model repository
The following target uploads a model to an existing Hugging Face model repository.
```console
+1
View File
@@ -112,6 +112,7 @@ int main(int argc, char ** argv) {
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
ctx_params.n_ubatch = ctx_params.n_batch;
}
@@ -0,0 +1,48 @@
---
base_model:
- {base_model}
---
# {model_name} GGUF
Recommended way to run this model:
```sh
llama-server -hf {namespace}/{model_name}-GGUF
```
Then the endpoint can be accessed at http://localhost:8080/embedding, for
example using `curl`:
```console
curl --request POST \
--url http://localhost:8080/embedding \
--header "Content-Type: application/json" \
--data '{{"input": "Hello embeddings"}}' \
--silent
```
Alternatively, the `llama-embedding` command line tool can be used:
```sh
llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings"
```
#### embd_normalize
When a model uses pooling, or the pooling method is specified using `--pooling`,
the normalization can be controlled by the `embd_normalize` parameter.
The default value is `2` which means that the embeddings are normalized using
the Euclidean norm (L2). Other options are:
* -1 No normalization
* 0 Max absolute
* 1 Taxicab
* 2 Euclidean/L2
* \>2 P-Norm
This can be passed in the request body to `llama-server`, for example:
```sh
--data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \
```
And for `llama-embedding`, by passing `--embd-normalize <value>`, for example:
```sh
llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings"
```
@@ -26,21 +26,31 @@ parser.add_argument('--namespace', '-ns', help='Namespace to add the model to',
parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
parser.add_argument('--private', '-p', action='store_true', help='Create private model')
parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
args = parser.parse_args()
repo_id = f"{args.namespace}/{args.model_name}-GGUF"
print("Repository ID: ", repo_id)
repo_url = api.create_repo(
repo_id=repo_id,
repo_type="model",
private=args.private,
exist_ok=False
)
repo_url = None
if not args.dry_run:
repo_url = api.create_repo(
repo_id=repo_id,
repo_type="model",
private=args.private,
exist_ok=False
)
if not args.no_card:
template_path = "scripts/readme.md.template"
if args.embedding:
template_path = "scripts/embedding/modelcard.template"
else:
template_path = "scripts/causal/modelcard.template"
print("Template path: ", template_path)
model_card_content = load_template_and_substitute(
template_path,
model_name=args.model_name,
@@ -48,16 +58,21 @@ if not args.no_card:
base_model=args.org_base_model,
)
if model_card_content:
api.upload_file(
path_or_fileobj=model_card_content.encode('utf-8'),
path_in_repo="README.md",
repo_id=repo_id
)
print("Model card created successfully.")
if args.dry_run:
print("\nTemplate Content:\n")
print(model_card_content)
else:
print("Failed to create model card.")
if model_card_content:
api.upload_file(
path_or_fileobj=model_card_content.encode('utf-8'),
path_in_repo="README.md",
repo_id=repo_id
)
print("Model card created successfully.")
else:
print("Failed to create model card.")
print(f"Repository created: {repo_url}")
if not args.dry_run and repo_url:
print(f"Repository created: {repo_url}")
+117 -78
View File
@@ -1257,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
if(acl_dst == nullptr) {
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
} else {
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
}
}
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
if(acl_dst == nullptr) {
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
} else {
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
}
}
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -2221,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
ggml_cann_release_resources(ctx, acl_index, acl_value);
}
/**
* @brief Initializes and caches sine/cosine positional encoding values
* (used in RoPE, Rotary Position Embedding) for attention layers.
*
* This function computes and caches the sin/cos values of
* θ = position * theta_scale for RoPE encoding. The cache is shared
* across attention layers, and only the first attention layer will
* trigger initialization. The cache includes repeated sin/cos values
* with different repeat methods depending on the @param is_neox flag.
*
* Steps performed by this function:
* 1. Identify whether the target tensor belongs to Q/K in attention
* and restrict computation to the first layer only.
* 2. Initialize the theta scale array (arange → power → freq scaling).
* 3. Allocate sin/cos caches if the max prompt length increases.
* 4. Compute θ = position * theta_scale.
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled.
* 7. Store the computed values into persistent buffers
* (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
*
* @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation
* depends on the cached RoPE values (usually Qcur/Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
*/
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
aclTensor* acl_cos_repeat_tensor,
aclTensor* acl_sin_repeat_tensor,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on
// @param.is_neox
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
// used for accuracy testing
bool is_attention = is_q || is_k;
// just compute in first layer in attention
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
if(is_attention && !is_fisrt_layer) {
return;
}
ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position
@@ -2253,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
}
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
// used for accuracy testing
bool is_attention = is_q || is_k;
if(ctx.init_ptr == nullptr || !is_attention) {
// init theta scale, just one time
if(ctx.rope_init_ptr == nullptr || !is_attention) {
// theta_scale arange, [0,1,...,ne00/2 - 1]
if(ctx.init_ptr != nullptr){
ACL_CHECK(aclrtFree(ctx.init_ptr));
if(ctx.rope_init_ptr != nullptr){
ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
}
ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float start = 0;
float step = 1;
@@ -2297,67 +2341,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
}
if(ctx.sin_ptr == nullptr) {
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}
// init sin_repeat && cos_repeat, one token just init in 0 layer
if(position_length > ctx.max_prompt_length) {
ctx.max_prompt_length = position_length;
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
ACL_CHECK(aclrtFree(ctx.sin_ptr));
ACL_CHECK(aclrtFree(ctx.cos_ptr));
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
if(ctx.rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
if(is_fisrt_layer || !is_attention) {
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
// position
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
// position
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type),
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* theta_buffer = theta_allocator.get();
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* theta_buffer = theta_allocator.get();
aclTensor* acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
acl_theta_tensor);
// sin/cos
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
// release
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
}
aclTensor* acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
acl_theta_tensor);
// sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* sin_buffer = sin_allocator.get();
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
theta_length * sizeof(float_t));
void* cos_buffer = cos_allocator.get();
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
// attn_factor
if (attn_factor != 1) {
@@ -2365,6 +2397,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
}
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat
if (is_neox) {
int64_t repeatsArray[] = {1, 1, 1, 2};
@@ -2380,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
num_repeats, output_size);
}
// release
ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
acl_cos_repeat_tensor);
}
#ifdef __cplusplus
@@ -2435,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
// init cos/sin cache
ggml_cann_pool_alloc sin_allocator(
ctx.pool(), ne00 * ne02 * sizeof(float_t));
ggml_cann_pool_alloc cos_allocator(
ctx.pool(), ne00 * ne02 * sizeof(float_t));
void* sin_buffer = sin_allocator.get();
void* cos_buffer = cos_allocator.get();
// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2450,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_reshape_tensor =
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor =
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
theta_scale, freq_scale, attn_factor, is_neox);
aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+18 -10
View File
@@ -368,10 +368,6 @@ struct ggml_backend_cann_context {
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
void* init_ptr = nullptr;
void* sin_ptr = nullptr;
void* cos_ptr = nullptr;
int64_t max_prompt_length = 65536;
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
@@ -379,6 +375,12 @@ struct ggml_backend_cann_context {
cann_task_queue task_queue;
bool async_mode;
bool support_set_rows;
// Rope Cache
void* rope_init_ptr = nullptr;
void* rope_sin_ptr = nullptr;
void* rope_cos_ptr = nullptr;
int64_t max_prompt_length = 0;
// Constant Pool
void* f32_zero_cache = nullptr;
void* f32_one_cache = nullptr;
int64_t f32_zero_cache_element = 0;
@@ -422,14 +424,20 @@ struct ggml_backend_cann_context {
ACL_CHECK(aclrtDestroyStream(streams[i]));
}
}
if(init_ptr != nullptr) {
ACL_CHECK(aclrtFree(init_ptr));
if(rope_init_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_init_ptr));
}
if(sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(sin_ptr));
if(rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_sin_ptr));
}
if(cos_ptr != nullptr) {
ACL_CHECK(aclrtFree(cos_ptr));
if(rope_cos_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_cos_ptr));
}
if(f32_zero_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_zero_cache));
}
if(f32_one_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_one_cache));
}
}
+20 -8
View File
@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
if (width == ggml_cuda_get_physical_warp_size()) {
return __all_sync(0xffffffff, x);
} else {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
for (int offset = width/2; offset > 0; offset >>= 1) {
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
}
return x;
}
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_any(int x) {
if (width == ggml_cuda_get_physical_warp_size()) {
return __any_sync(0xffffffff, x);
} else {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
}
return x;
}
return x;
#else
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
return __all_sync(0xffffffff, x);
#endif // GGML_USE_HIP
}
template<int width = WARP_SIZE>
+21 -1
View File
@@ -204,6 +204,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
#endif // GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
@@ -261,7 +263,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#endif // defined(GGML_USE_HIP)
std::string device_name(prop.name);
if (device_name == "NVIDIA GeForce MX450") {
turing_devices_without_mma.push_back({ id, device_name });
} else if (device_name == "NVIDIA GeForce MX550") {
turing_devices_without_mma.push_back({ id, device_name });
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
#endif // defined(GGML_USE_HIP)
}
if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
GGML_LOG_INFO(
" Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
}
GGML_LOG_INFO(
"Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
}
for (int id = 0; id < info.device_count; ++id) {
+177 -47
View File
@@ -3,6 +3,140 @@
#include <vector>
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mmq_ids_helper_store {
uint32_t data;
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
data = (it & 0x003FFFFF) | (iex_used << 22);
}
__device__ uint32_t it() const {
return data & 0x003FFFFF;
}
__device__ uint32_t iex_used() const {
return data >> 22;
}
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
const int expert = blockIdx.x;
extern __shared__ char data_mmq_ids_helper[];
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
int nex_prev = 0; // Number of columns for experts with a lower index.
int it_compact = 0; // Running index for the compact slice of this expert.
if constexpr (n_expert_used_template == 0) {
// Generic implementation:
for (int it = 0; it < n_tokens; ++it) {
int iex_used = -1; // The index at which the expert is used, if any.
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
const int expert_used = ids[it*si1 + iex];
nex_prev += expert_used < expert;
if (expert_used == expert) {
iex_used = iex;
}
}
if (iex_used != -1) {
store[it_compact] = mmq_ids_helper_store(it, iex_used);
}
if (warp_reduce_any<warp_size>(iex_used != -1)) {
it_compact++;
}
}
} else {
// Implementation optimized for specific numbers of experts used:
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
const int it = it0 + threadIdx.x / neu_padded;
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
ids[it*si1 + iex] : INT_MAX;
const int iex_used = expert_used == expert ? iex : -1;
nex_prev += expert_used < expert;
// Whether the threads at this token position have used the expert:
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
int it_compact_add_lower = 0;
#pragma unroll
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
if (threadIdx.x >= offset) {
it_compact_add_lower += tmp;
}
}
if (iex_used != -1) {
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
}
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
}
}
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
const mmq_ids_helper_store store_it = store[itc];
const int it = store_it.it();
const int iex_used = store_it.iex_used();
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
}
if (threadIdx.x != 0) {
return;
}
expert_bounds[expert] = nex_prev;
if (expert < gridDim.x - 1) {
return;
}
expert_bounds[gridDim.x] = nex_prev + it_compact;
}
template <int n_expert_used_template>
static void launch_mmq_ids_helper(
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
const int id = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
const dim3 num_blocks(n_experts, 1, 1);
const dim3 block_size(warp_size, 1, 1);
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
GGML_ASSERT(nbytes_shared <= smpbo);
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q4_0:
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
ne00, ne01, ne1, s01, ne11, s1,
ne02, ne12, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k};
use_stream_k, ne1};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
return;
}
@@ -148,54 +282,50 @@ void ggml_cuda_mul_mat_q(
const int64_t n_expert_used = ids->ne[0];
const int64_t ne_get_rows = ne12 * n_expert_used;
GGML_ASSERT(ne1 == n_expert_used);
std::vector<char> ids_host(ggml_nbytes(ids));
std::vector<int32_t> ids_src1_host;
ids_src1_host.reserve(ne_get_rows);
std::vector<int32_t> ids_dst_host;
ids_dst_host.reserve(ne_get_rows);
std::vector<int32_t> tokens_per_expert_host(ne02);
std::vector<int32_t> expert_bounds_host(ne02 + 1);
ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
{
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
const int si1 = ids->nb[1] / ggml_element_size(ids);
const int sis1 = nb12 / nb11;
for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
for (int64_t iex = 0; iex < n_expert_used; ++iex) {
const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
assert(expert_to_use >= 0 && expert_to_use < ne02);
if (expert_to_use == i02) {
ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
ids_dst_host.push_back(i12*ne1 + iex);
tokens_per_expert_host[i02]++;
break;
}
}
switch (n_expert_used) {
case 2:
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 4:
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 6:
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 8:
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 16:
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
case 32:
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
default:
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
break;
}
CUDA_CHECK(cudaGetLastError());
}
int32_t cumsum = 0;
for (int64_t i = 0; i < ne02; ++i) {
expert_bounds_host[i] = cumsum;
cumsum += tokens_per_expert_host[i];
}
expert_bounds_host[ne02] = cumsum;
std::vector<int32_t> ids_buf_host;
ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
const int32_t * ids_src1_dev = ids_buf_dev.ptr;
const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[2] / ts_src1;
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
CUDA_CHECK(cudaGetLastError());
}
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
const mmq_args args = {
src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
ne02, ne02, s02, s12, s2,
ne03, ne13, s03, s13, s3,
use_stream_k};
use_stream_k, ne12};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
}
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
1, 1, 0, 0, 0,
1, 1, 0, 0, 0,
use_stream_k};
use_stream_k, src1_ncols};
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+21 -13
View File
@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const int ncols_max) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
// Initialize the ids for writing back data with just the index.
@@ -3376,7 +3377,8 @@ template <ggml_type type, int mmq_x, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
const int bidx0 = blockIdx.x;
@@ -3528,7 +3530,7 @@ struct mmq_args {
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
bool use_stream_k;
bool use_stream_k; int64_t ncols_max;
};
template<ggml_type type>
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
const int ntzw = args.nchannels_y * args.nsamples_y;
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
}
return;
}
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
if (!fixup_needed) {
return;
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
args.ncols_max);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
args.ncols_max);
if (!fixup_needed) {
return;
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
args.ncols_max);
}
}
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
continue;
}
const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
if (ntiles_x < ntiles_x_best) {
mmq_x_best = mmq_x;
+52
View File
@@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
// q4 contains 8 indices with 4 bit each.
// This function selects those bytes from table that are at those indices and returns them as int2.
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
#if defined(GGML_USE_HIP)
// Load the 16-byte table into four 32-bit unsigned integers.
const uint32_t *values = (const uint32_t *)table;
const uint32_t q_even = q4;
const uint32_t q_odd = (q4 >> 4);
// Perform lookups in the lower half of the table (indices 0-7).
uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
// Perform lookups in the upper half of the table (indices 8-15).
uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
// Select between the low and high results based on the MSB of each index nibble.
uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
return make_int2(res_x, res_y);
#elif !defined(GGML_USE_MUSA)
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
const uint32_t * table32 = (const uint32_t *) table;
// __byte_perm selects bytes based on the lower 16 bits in its third argument.
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
// To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
uint32_t tmp[2];
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
#pragma unroll
for (uint32_t i = 0; i < 2; ++i) {
const uint32_t shift = 16 * i;
const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
}
// tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.
// However, for the result we need ints with all even/odd 4 bit indices in q4.
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
#else
// Generic implementation.
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
@@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+3
View File
@@ -22,7 +22,10 @@
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define __all_sync(mask, var) __all(var)
#define __any_sync(mask, var) __any(var)
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
+11 -20
View File
@@ -320,40 +320,31 @@ typedef struct {
} ggml_metal_kargs_mul_mv_ext;
typedef struct {
int32_t ne02;
int32_t ne10;
int32_t ne11; // n_expert_used (bcast)
uint64_t nb11;
uint64_t nb12;
int32_t neh11; // n_tokens
uint64_t nbh11;
int32_t ne21; // n_tokens
int32_t ne20; // n_expert_used
uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
typedef struct {
int32_t ne20; // n_expert_used
int32_t neh0;
int32_t neh1;
uint64_t nbh1;
uint64_t nbh2;
int32_t ne0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_mul_mm_id_map1;
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t neh12;
uint64_t nbh10;
uint64_t nbh11;
uint64_t nbh12;
uint64_t nbh13;
int32_t neh0;
int32_t neh1;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne20;
int32_t ne21;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm_id;
+126 -109
View File
@@ -93,35 +93,37 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
if (ctx->mtl_device) {
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
#endif
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
#if defined(GGML_METAL_USE_BF16)
ctx->use_bfloat = ctx->has_bfloat;
ctx->use_bfloat = ctx->has_bfloat;
#else
ctx->use_bfloat = false;
ctx->use_bfloat = false;
#endif
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
{
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
ctx->debug_fusion = val ? atoi(val) : 0;
{
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
ctx->debug_fusion = val ? atoi(val) : 0;
}
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
}
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
}
ctx->mtl_device_ref_count++;
@@ -396,8 +398,12 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -443,6 +449,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
@@ -452,6 +459,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -461,6 +469,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -470,6 +479,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -479,6 +489,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -488,6 +499,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -497,6 +509,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -506,6 +519,13 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@@ -1412,8 +1432,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
@@ -1459,6 +1483,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
@@ -1468,6 +1493,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
@@ -1477,6 +1503,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -1486,6 +1513,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
@@ -1495,6 +1523,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
@@ -1504,6 +1533,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
@@ -1513,6 +1543,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@@ -1522,6 +1553,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
@@ -1846,7 +1884,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL:
return op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
case GGML_OP_POOL_1D:
return false;
case GGML_OP_UPSCALE:
@@ -3878,38 +3916,6 @@ static int ggml_metal_encode_node(
default: break;
}
const int64_t neh10 = ne10; // n_embd
const int64_t neh11 = ne21; // n_tokens
const int64_t neh12 = ne02; // n_expert
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
const uint64_t nbh11 = nbh10*neh10;
const uint64_t nbh12 = nbh11*neh11;
const uint64_t nbh13 = nbh12*neh12;
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
if (!h_src1) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
return 0;
}
const int64_t neh0 = ne0;
const int64_t neh1 = ne21;
const int64_t neh2 = ne02;
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
const uint64_t nbh1 = nbh0*neh0;
const uint64_t nbh2 = nbh1*neh1;
//const uint64_t nbh3 = nbh2*neh2;
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
if (!h_dst) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
return 0;
}
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -3919,8 +3925,8 @@ static int ggml_metal_encode_node(
}
// id map
// [n_expert_used, n_tokens]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
// [n_tokens, n_expert]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
@@ -3928,32 +3934,45 @@ static int ggml_metal_encode_node(
}
{
const int nth = MIN(1024, ne10/4);
ggml_metal_kargs_mul_mm_id_map0 args = {
ne02,
ne10,
ne11, // n_expert_used (bcast)
ne11, // n_expert_used (bcast)
nb11,
nb12,
neh11, // n_tokens
nbh11,
ne20, // n_expert_used
ne21, // n_tokens
ne20, // n_expert_used
nb21,
};
id<MTLComputePipelineState> pipeline = nil;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
pipeline = nil;
switch (ne20) {
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
}
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
const size_t smem = ne02*ne20*sizeof(uint16_t);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer: h_src1 offset:0 atIndex:3];
[encoder setBuffer: h_tpe offset:0 atIndex:4];
[encoder setBuffer: h_ids offset:0 atIndex:5];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
[encoder setBuffer: h_tpe offset:0 atIndex:2];
[encoder setBuffer: h_ids offset:0 atIndex:3];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
{
@@ -3992,13 +4011,15 @@ static int ggml_metal_encode_node(
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.neh12 =*/ neh12,
/*.nbh10 =*/ nbh10,
/*.nbh11 =*/ nbh11,
/*.nbh12 =*/ nbh12,
/*.nbh13 =*/ nbh13,
/*.neh0 =*/ neh0,
/*.neh1 =*/ neh1,
/*.ne11 =*/ ne11, // n_expert_used (bcast)
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne20 =*/ ne20, // n_expert_used
/*.ne21 =*/ ne21, // n_tokens
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
@@ -4006,42 +4027,14 @@ static int ggml_metal_encode_node(
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer: h_src1 offset:0 atIndex:2];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3];
[encoder setBuffer: h_dst offset:0 atIndex:4];
[encoder setBuffer: h_ids offset:0 atIndex:4];
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
{
GGML_ASSERT(ne0 % 4 == 0);
const int nth = MIN(1024, ne0/4);
ggml_metal_kargs_mul_mm_id_map1 args = {
ne20, // n_expert_used
neh0,
neh1,
nbh1,
nbh2,
ne0,
nb1,
nb2,
};
id<MTLComputePipelineState> pipeline = nil;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer: h_dst offset:0 atIndex:1];
[encoder setBuffer: h_ids offset:0 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} else {
id<MTLComputePipelineState> pipeline = nil;
@@ -4701,7 +4694,6 @@ static int ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -5130,6 +5122,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
@@ -5154,6 +5147,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
@@ -5178,6 +5172,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
@@ -5202,6 +5197,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
@@ -5226,6 +5222,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
@@ -5250,6 +5247,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
@@ -5274,6 +5272,7 @@ static int ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
} else {
switch (ne00) {
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
@@ -5301,6 +5300,24 @@ static int ggml_metal_encode_node(
use_vec_kernel = true;
switch (ne00) {
case 40:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
case 64:
{
switch (src1->type) {
+133 -119
View File
@@ -974,9 +974,16 @@ kernel void kernel_mul(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
if (args.ne10 == 1) {
const float x = *((device float *)(src1_ptr));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
}
}
}
@@ -1000,9 +1007,16 @@ kernel void kernel_div(
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
if (args.ne10 == 1) {
const float x = 1.0f / *((device float *)(src1_ptr));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
}
}
}
@@ -4663,6 +4677,7 @@ kernel void kernel_flash_attn_ext(
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
@@ -4674,6 +4689,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
@@ -4685,6 +4701,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif
template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
@@ -4695,6 +4712,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
@@ -4705,6 +4723,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
@@ -4715,6 +4734,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
@@ -4725,6 +4745,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
@@ -5115,6 +5136,16 @@ kernel void kernel_flash_attn_ext_vec(
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 40, 40, 8>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 40, 40, 8>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
@@ -7474,97 +7505,81 @@ kernel void kernel_mul_mm(
}
}
template<typename T4>
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
device const char * src1,
device const char * src2,
device char * hsrc1,
device char * htpe,
device char * hids,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ide = tgpig[0]; // expert id
threadgroup char * shmem [[threadgroup(0)]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const short ide = tpitg; // expert id
int n_all = 0;
uint32_t n_all = 0;
device int32_t * ids_i32 = (device int32_t *) (hids);
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
if (i21 + tpitg < args.ne21) {
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
if (src2_i32[i20] != ide) {
continue;
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sids[i20] = src2_i32[i20];
}
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
}
if (tpitg.x == 0) {
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
}
++n_all;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short t = 0; t < ntg; t++) {
if (i21 + t >= args.ne21) {
break;
}
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
short sel = 0;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sel += (sids[i20] == ide)*(i20 + 1);
}
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
n_all += sel > 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (tpitg.x == 0) {
device int32_t * tpe_i32 = (device int32_t *) (htpe);
tpe_i32[ide] = n_all;
}
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
tpe_u32[ide] = n_all;
}
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
template<typename T>
kernel void kernel_mul_mm_id_map1(
constant ggml_metal_kargs_mul_mm_id_map1 & args,
device const char * hdst,
device const char * hids,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i20 = tgpig[0]; // used expert
const int i21 = tgpig[1]; // token
device const int32_t * ids_i32 = (device const int32_t *) (hids);
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
const int id = ids_i32[i21*args.ne20 + i20];
const int ide = id / args.neh1;
const int idt = id % args.neh1;
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
dst_f32x4[i0] = hdst_f32x4[i0];
}
}
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
device const char * tpe,
device const char * htpe,
device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7572,19 +7587,20 @@ kernel void kernel_mul_mm_id(
const int r0 = tgpig.y;
const int r1 = tgpig.x;
const int im = tgpig.z;
const int im = tgpig.z; // expert
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
device const int32_t * ids_i32 = (device const int32_t *) (hids);
const int neh1 = tpe_i32[im];
const int32_t neh1 = tpe_u32[im];
if (r1*BLOCK_SIZE_N >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -7600,20 +7616,23 @@ kernel void kernel_mul_mm_id(
short il = (tiitg % THREAD_PER_ROW);
const int i12 = im%args.neh12;
const int i13 = im/args.neh12;
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const short i11 = (id % args.ne20) % args.ne11;
const short i12 = (id / args.ne20);
const short i13 = 0;
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
device const half * y = (device const half *)(src1
+ args.nbh13*i13
+ args.nbh12*i12
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
device const float * y = (device const float *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*i11
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
@@ -7629,7 +7648,7 @@ kernel void kernel_mul_mm_id(
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -7665,43 +7684,38 @@ kernel void kernel_mul_mm_id(
}
}
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
device float * C = (device float *) dst +
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
#pragma unroll(8)
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short j = sgitg; j < n_cols; j += 4) {
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
const short ide = id % args.ne20;
const short idt = id / args.ne20;
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = tiisg;
for (; i < n_rows/4; i += 32) {
*(D4 + i) = *(C4 + i);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < n_rows/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < n_rows; i++) {
*(D + i) = *(C + i);
}
}
i = (4*(n_rows/4)) + tiisg;
for (; i < n_rows; i += 32) {
*(D + i) = *(C + i);
}
}
}
+2 -1
View File
@@ -2647,8 +2647,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_RMS_NORM:
return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
+253 -213
View File
@@ -388,6 +388,7 @@ struct vk_device_struct {
bool float_controls_rte_fp16;
bool subgroup_add;
bool subgroup_shuffle;
bool subgroup_ballot;
bool multi_add;
bool add_rms_fusion;
@@ -1044,7 +1045,7 @@ struct vk_op_sum_rows_push_constants
uint32_t ne0_1mp, ne0_1L;
};
vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
uint32_t type_size = (uint32_t)ggml_type_size(src->type);
vk_op_sum_rows_push_constants p = {};
p.n_cols = (uint32_t)n_cols;
@@ -2089,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -2176,8 +2178,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
(device->subgroup_size_control && device->subgroup_max_size >= 16);
// mulmat
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
@@ -2248,9 +2259,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2276,14 +2296,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
// Disable mul_mat_id if not enough shared memory is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
device->mul_mat_id_s[i] = false;
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
device->mul_mat_id_l[i] = false;
}
}
@@ -2461,32 +2481,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
GGML_ASSERT(device->subgroup_ballot);
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
}
#endif
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -2573,55 +2595,56 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
GGML_ASSERT(device->subgroup_ballot);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
}
#endif
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
@@ -2638,38 +2661,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2681,51 +2704,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
}
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
@@ -2735,34 +2784,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2774,33 +2823,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
}
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
@@ -2817,8 +2892,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_wg_denoms = { 64, 64, 1 };
s_wg_denoms = { 32, 32, 1 };
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
}
#undef CREATE_MM
@@ -3506,6 +3581,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -3655,9 +3733,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
subgroup_size_control_features.subgroupSizeControl;
if (device->subgroup_size_control) {
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
}
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
#if defined(VK_KHR_cooperative_matrix)
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
@@ -6213,7 +6289,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
GGML_ASSERT(nei0 * nei1 <= 4096);
const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2];
@@ -6653,37 +6728,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
// Split based on number of ids, to fit in shared memory
const uint32_t nei0 = (uint32_t)src2->ne[0];
const uint32_t nei1 = (uint32_t)src2->ne[1];
GGML_ASSERT(nei0 <= 4096);
const uint32_t split_size = std::min(nei1, 4096u / nei0);
if (split_size == nei1) {
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
ggml_tensor src1_copy = *src1;
ggml_tensor src2_copy = *src2;
ggml_tensor dst_copy = *dst;
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
src1_copy.ne[2] = n_tokens;
src2_copy.ne[1] = n_tokens;
dst_copy.ne[2] = n_tokens;
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
// invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
ctx->prealloc_y_last_pipeline_used = {};
ctx->prealloc_y_last_tensor_used = nullptr;
}
}
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
}
}
@@ -10194,12 +10239,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
if (need_sync) {
VK_LOG_DEBUG("node_idx=" << i << " sync");
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
} else {
VK_LOG_DEBUG("node_idx=" << i << " unsynced");
}
// Add the last fused node and all fused source nodes to the unsynchronized list.
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
@@ -12241,7 +12283,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_CONCAT) {
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
} else if (tensor->op == GGML_OP_UPSCALE) {
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
} else if (tensor->op == GGML_OP_SCALE) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
@@ -12480,11 +12522,9 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
return;
}
bool fused_rms_norm_mul = false;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}
+27 -17
View File
@@ -17,6 +17,9 @@
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#endif
#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
@@ -106,11 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef COOPMAT
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
@@ -160,15 +165,18 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
#ifdef COOPMAT
@@ -235,18 +243,20 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
#ifdef COOPMAT
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true);
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false);
load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++;
}
}
@@ -792,7 +802,7 @@ void main() {
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -808,7 +818,7 @@ void main() {
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -827,7 +837,7 @@ void main() {
#else
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1 && block + loadr_b < end_k) {
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[loadc_b + l];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
@@ -898,7 +908,7 @@ void main() {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
if (dr + cm_row * TM + store_r < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
@@ -948,7 +958,7 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
@@ -93,7 +93,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
shared u16vec4 row_ids[4096];
shared u16vec4 row_ids[BN];
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
B_TYPE b[];
@@ -111,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
return B_TYPE(0.0);
}
const u16vec4 row_idx = row_ids[row_i];
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
return ret;
@@ -123,14 +123,14 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
uint dc = ic * BN + c;
if (dr < p.M && dc < _ne1) {
uint row_i = dc;
uint row_i = c;
const u16vec4 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
}
return elem;
}
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
@@ -180,11 +180,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
@@ -218,9 +221,9 @@ void main() {
#ifdef MUL_MAT_ID
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true);
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false);
load_row_ids(expert_idx, false, ic);
}
// Workgroup has no work
@@ -68,6 +68,12 @@ const std::vector<std::string> type_names = {
"bf16",
};
enum MatMulIdType {
NONE,
DEFAULT,
SUBGROUP,
};
namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
@@ -293,7 +299,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
}
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@@ -303,9 +309,13 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
};
std::string shader_name = "matmul";
if (matmul_id) {
if (matmul_id_type == MatMulIdType::DEFAULT) {
base_dict["MUL_MAT_ID"] = "1";
shader_name = "matmul_id";
} else if (matmul_id_type == MatMulIdType::SUBGROUP) {
base_dict["MUL_MAT_ID"] = "1";
base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
shader_name = "matmul_id_subgroup";
}
if (fp16) {
@@ -389,7 +399,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -401,26 +411,28 @@ void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
// matmul
for (const auto& matmul_id : {false, true}) {
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
// No coopmats
// fp32
matmul_shaders(false, matmul_id, false, false, false);
matmul_shaders(false, matmul_id_type, false, false, false);
// fp16, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, false, false);
matmul_shaders(true, matmul_id, false, false, true);
matmul_shaders(true, matmul_id_type, false, false, false);
matmul_shaders(true, matmul_id_type, false, false, true);
if (matmul_id_type != MatMulIdType::DEFAULT) {
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id, true, false, false);
matmul_shaders(true, matmul_id, true, false, true);
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id_type, true, false, false);
matmul_shaders(true, matmul_id_type, true, false, true);
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, true, false);
matmul_shaders(true, matmul_id, false, true, true);
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id_type, false, true, false);
matmul_shaders(true, matmul_id_type, false, true, true);
#endif
}
}
// flash attention
-1
View File
@@ -427,7 +427,6 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
"model.layers.{bid}.block_sparse_moe.gate", # smallthinker
"model.transformer.blocks.{bid}.ff_proj", # llada
"layers.{bid}.mlp.gate_proj", # qwen3-embedding
),
+12 -10
View File
@@ -280,7 +280,7 @@ llama_context::llama_context(
}
// reserve worst-case graph
if (!hparams.vocab_only && memory) {
if (!hparams.vocab_only) {
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
@@ -292,11 +292,13 @@ llama_context::llama_context(
int n_splits_tg = -1;
int n_nodes_tg = -1;
// simulate full KV cache
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize KV cache");
llama_memory_context_ptr mctx;
if (memory) {
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory module");
}
}
cross.v_embd.clear();
@@ -1056,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
llama_pos pos_min[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1073,7 +1075,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
continue;
}
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
memory->seq_rm(s, pos_min[s], -1);
}
@@ -1857,7 +1859,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
}
if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io);
}
@@ -1943,7 +1945,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
}
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
memory->state_read(io);
}
+25
View File
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error");
}
bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;
}
return false;
}
// by default, all layers have kv
return true;
}
uint32_t llama_hparams::n_layer_kv() const {
uint32_t res = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (has_kv(il)) {
res++;
}
}
return res;
}
+6
View File
@@ -41,6 +41,7 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
uint32_t n_pos_per_embd() const;
bool is_swa(uint32_t il) const;
bool has_kv(uint32_t il) const;
// number of layers for which has_kv() returns true
uint32_t n_layer_kv() const;
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
+24 -7
View File
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return !model.hparams.is_swa(il);
};
const layer_filter_cb filter_swa = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return model.hparams.is_swa(il);
};
const uint32_t size_base = kv_size;
@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
kv_base = std::make_unique<llama_kv_cache>(
model, std::move(filter_base), type_k, type_v,
model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, std::move(filter_swa), type_k, type_v,
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}
void llama_kv_cache_iswa::clear(bool data) {
+4 -2
View File
@@ -20,11 +20,13 @@ public:
bool v_trans,
bool offload,
bool swa_full,
bool ,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad);
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_iswa() = default;
+35 -33
View File
@@ -17,32 +17,25 @@
//
llama_kv_cache::llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type) :
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % n_pad == 0);
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
auto n_layer_cache = hparams.n_layer;
if (model.arch == LLM_ARCH_GEMMA3N) {
n_layer_cache = 20;
}
if (model.arch == LLM_ARCH_GLM4_MOE) {
// GLM-4.5: Only process up to last layer, skip final NextN layer
n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
}
const uint32_t n_layer_kv = hparams.n_layer_kv();
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -50,7 +43,7 @@ llama_kv_cache::llama_kv_cache(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
__func__, hparams.n_embd_v_gqa_max());
}
for (uint32_t il = 0; il < n_layer_cache; il++) {
for (uint32_t il = 0; il < hparams.n_layer; il++) {
if (!hparams.has_kv(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
continue;
}
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
continue;
}
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
layers.push_back({ il, k, v, k_stream, v_stream, });
}
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
if (model.arch == LLM_ARCH_GEMMA3N) {
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
if (reuse) {
LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
for (uint32_t il = 0; il < hparams.n_layer; il++) {
const int32_t il_reuse = reuse(il);
if (il_reuse < 0) {
LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
continue;
}
const bool is_swa = hparams.is_swa(il);
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
continue;
}
GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
map_layer_ids[il] = map_layer_ids[il_reuse];
LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
}
}
+13 -15
View File
@@ -21,9 +21,6 @@ class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
@@ -82,18 +79,19 @@ public:
using slot_info_vec_t = std::vector<slot_info>;
llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache() = default;
+29 -28
View File
@@ -9,32 +9,29 @@
//
llama_memory_hybrid::llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn,
const layer_filter_cb & filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache(
model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
type_k,
type_v,
v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max,
n_pad,
n_swa,
swa_type
swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)),
mem_recr(new llama_memory_recurrent(
model,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr,
type_r,
type_s,
offload,
rs_size,
n_seq_max
n_seq_max,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
)) {}
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+18 -22
View File
@@ -18,31 +18,27 @@
class llama_memory_hybrid : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn = nullptr,
layer_filter_cb && filter_recr = nullptr);
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn = nullptr,
const layer_filter_cb & filter_recr = nullptr);
~llama_memory_hybrid() = default;
+7 -7
View File
@@ -16,13 +16,13 @@
//
llama_memory_recurrent::llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const llama_model & model,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
head = 0;
+7 -11
View File
@@ -15,18 +15,14 @@
// see the implementation of llama_kv_cache_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max);
const llama_model & model,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
const layer_filter_cb & filter);
~llama_memory_recurrent() = default;
+8
View File
@@ -3,6 +3,7 @@
#include "llama.h"
#include <memory>
#include <functional>
struct llama_ubatch;
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
// this callback is used to specify which layers should reuse memory from other layers
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
+27 -11
View File
@@ -1115,6 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(5);
hparams.n_layer_kv_from_start = 20;
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
hparams.f_attention_scale = 1.0f;
@@ -1474,12 +1475,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// Expert gating function (GLM-4.5 uses sigmoid)
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
}
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
@@ -10524,7 +10528,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
const int64_t n_embd_altup;
const int64_t n_altup;
const int i_altup_act;
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
const int n_layer_sparsity = 10; // number of layers using activation sparsity
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
@@ -10574,8 +10577,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
const bool has_kv = (il < n_layer_kv);
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -10595,7 +10596,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
// self-attention
if (has_kv) {
if (hparams.has_kv(il)) {
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
@@ -10635,7 +10636,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
} else {
// no KV layers
// reuse KV cache of earlier layers
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -18256,12 +18257,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (llm_arch_is_recurrent(arch)) {
res = new llama_memory_recurrent(
*this,
nullptr,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
cparams.n_seq_max,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
const auto padding = llama_kv_cache::get_padding(cparams);
@@ -18302,6 +18303,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
llama_memory_i::layer_reuse_cb reuse = nullptr;
if (arch == LLM_ARCH_GEMMA3N) {
reuse = [&](int32_t il) {
if (il >= (int32_t) hparams.n_layer_kv_from_start) {
return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
}
return -1;
};
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
@@ -18316,13 +18329,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
n_ctx_per_stream,
cparams.n_seq_max,
cparams.n_ubatch,
padding);
padding,
nullptr,
reuse);
} else {
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache(
*this,
nullptr,
params.type_k,
params.type_v,
!cparams.flash_attn,
@@ -18332,7 +18346,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
padding,
hparams.n_swa,
hparams.swa_type);
hparams.swa_type,
nullptr,
nullptr);
}
}
}
+22
View File
@@ -2209,6 +2209,26 @@ struct test_count_equal : public test_case {
double max_nmse_err() override {
return 0.0;
}
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_F32) {
// initialize with unique values to avoid ties
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
}
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_REPEAT
@@ -5997,6 +6017,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// test large experts*tokens
for (bool b : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));
}
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
+14 -3
View File
@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
const int tg = n_tg[i_tg];
const int pl = n_pl[i_pl];
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);
if (n_ctx_req > n_kv_max) {
continue;
@@ -147,13 +147,24 @@ int main(int argc, char ** argv) {
return 1;
}
const auto t_pp_end = ggml_time_us();
if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_memory_seq_cp(mem, 0, i, -1, -1);
}
}
const auto t_pp_end = ggml_time_us();
if (!params.kv_unified) {
// run one dummy token to apply the memory copy
common_batch_clear(batch);
common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
llama_memory_seq_rm(mem, 0, pp, -1);
}
}
const auto t_tg_start = ggml_time_us();
+5
View File
@@ -2202,6 +2202,8 @@ struct clip_model_loader {
hparams.minicpmv_query_num = 64;
} else if (hparams.minicpmv_version == 5) {
hparams.minicpmv_query_num = 64;
} else if (hparams.minicpmv_version == 6) {
hparams.minicpmv_query_num = 64;
} else {
hparams.minicpmv_query_num = 96;
}
@@ -3685,6 +3687,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} else if (params.minicpmv_version == 5) {
// MiniCPM-V 4.0
n_patches = 64;
} else if (params.minicpmv_version == 6) {
// MiniCPM-V 4.5
n_patches = 64;
} else {
GGML_ABORT("Unknown minicpmv version");
}
@@ -607,6 +607,9 @@ else:
elif minicpmv_version == 5:
emb_dim = 2560
block_count = 27
elif minicpmv_version == 6:
emb_dim = 4096
block_count = 27
default_vision_config = {
"hidden_size": 1152,
@@ -630,6 +633,10 @@ elif minicpmv_version == 5:
default_vision_config["model_type"] = "siglip_vision_model"
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 6:
default_vision_config["model_type"] = "siglip_vision_model"
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
processor = None
# if model.attn_pool is not None:
+1 -1
View File
@@ -207,7 +207,7 @@ struct mtmd_context {
tok_row_end_trail = false; // no trailing end-of-row token
ov_img_first = true;
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5) {
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6) {
// minicpmv 2.6 format:
// <image> (overview) </image><slice> (slice) </slice><slice> (slice) </slice>\n ...
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;