Compare commits

..

49 Commits

Author SHA1 Message Date
rmatif 60a7658810 opencl: allow mixed f16/f32 add (#15140) 2025-08-12 02:42:41 -07:00
Aman Gupta efe3a90996 CUDA cmake: add -lineinfo for easier debug (#15260) 2025-08-12 17:21:45 +08:00
Chenguang Li bbd57b7eaf CANN: GGML_OP_CPY optimization (#15070)
Signed-off-by: noemotiovon <757486878@qq.com>
2025-08-12 16:12:13 +08:00
R0CKSTAR 25ff6f7659 musa: fix failures in test-backend-ops for mul_mat_id op (#15236)
* musa: fix failures in test-backend-ops for mul_mat_id op

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Address review comments

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2025-08-12 10:02:51 +08:00
hipudding be48528b06 CANN: Add broadcast for softmax and FA (#15208)
* refactor softmax

* fix fa

* fix mask shape

* format

* add comments

* Remove whitespace
2025-08-11 22:50:31 +08:00
rainred cf9e5648a7 mtmd : Fix MinicpmV model converter and clip to avoid using hardcode. (#14750)
* Fix MinicpmV model converter and clip to avoid using hardcode.

* Code update for pr/14750

* Remove unused field, update script path in docs.

* Add version 5 for fallback code.

---------

Co-authored-by: lzhang <zhanglei@modelbest.cn>
2025-08-11 16:12:12 +02:00
Xuan-Son Nguyen fba5c0d680 chat : hotfix gpt-oss jinja raising an exception (#15243)
* chat : hotfix gpt-oss jinja raising an exception

* fix
2025-08-11 15:31:35 +02:00
Xuan-Son Nguyen 53d0a12658 server : allow specifying reasoning_format in HTTP request (#15238) 2025-08-11 14:48:41 +02:00
Zagaj 27093afe78 readme : update infra list (#15234) 2025-08-11 15:27:54 +03:00
Georgi Gerganov 228f724d9c kv-cache : fix seq_rm with seq_id == -1 (#15226)
* kv-cache : fix seq_rm with seq_id == -1

ggml-ci

* cont : iterate over streams

ggml-ci
2025-08-11 13:58:24 +03:00
Daniel Bevenius cd3069dfcb kv-cache : log (debug) all streams in find_slot (#15176)
This commit updates `llama_kv_cache_unified::find_slot` to log
information for all streams when debug is enabled.

The motivation for this change is that currently if a non-unified
kv-cache is used, then only one stream will be logged because the
code was currently uses `seq_to_stream[1]`.
2025-08-11 11:21:19 +02:00
Sigbjørn Skjæret 50e81bdf5d convert : fix merge conflicts (#15229) 2025-08-11 11:15:44 +02:00
Daniel Bevenius 1ebbaddff2 perplexity : update comments/error msg to use decode [no ci] (#15227)
This commit updates comments and error messages to use "decode" instead
of "eval" in perplexity.cpp.

The motivation for this is that `llama_eval` was renamed to
`llama_decode` a while ago, but the comments and error messages
still referred to "eval". This change ensures consistency and clarity.
2025-08-11 11:21:24 +03:00
Julien Denize a3a7874272 convert : improve Mistral models integration (#14737)
* Improve Mistral models integration with llama.cpp

* Revert changes and fix gguf

* Revert change

* refactor convert_mistral_to_gguf.py in convert_hf_to_gguf.py

* Revert collateral

* Rename model name

* refactor

* revert

* remove duplicate

* Remove duplication code

* Fixes

* Fix flake issues

* Apply comments

* Apply comments

* Apply comments

* Fix remote

* add default chat template

* Revert

* nit
2025-08-11 10:07:49 +02:00
Charles Xu 002cb1bb33 kleidiai: fix unsigned overflow bug (#15150)
* kleidiai: fix unsigned overflow bug

* address review comments
2025-08-11 09:59:26 +02:00
David Zhao 79c1160b07 cuda: refactored ssm_scan and use CUB (#13291)
* cuda: refactored ssm_scan to use CUB

* fixed compilation error when when not using CUB

* assign L to constant and use size_t instead of int

* deduplicated functions

* change min blocks per mp to 1

* Use cub load and store warp transpose

* suppress clang warning
2025-08-09 20:29:43 +02:00
Aman Gupta 34c9d765bf CUDA: add attention sinks for tile and wmma (#15178)
* CUDA: add attention sinks for tile and wmma

* Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma
2025-08-09 20:00:24 +08:00
compilade e54d41befc gguf-py : add Numpy MXFP4 de/quantization support (#15111)
* gguf-py : add MXFP4 de/quantization support

* ggml-quants : handle zero amax for MXFP4
2025-08-08 17:48:26 -04:00
Johannes Gäßler 4850b52aed server-bench: external OAI servers, sqlite (#15179)
* server-bench: external OAI servers, sqlite

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* raise_for_status

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-08-08 23:04:36 +02:00
AN Long cd6983d56d ggml : fix field name when new ggml_backend (#14944) 2025-08-08 14:37:22 +02:00
Olivier Chafik 6c7e9a5440 vendor: sync minja (#15161)
* vendor: sync minja

* Update minja.hpp

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-08-08 10:45:18 +01:00
Johannes Gäßler 1425f587a8 CUDA: attention sinks for mma FlashAttention (#15157) 2025-08-08 08:19:58 +02:00
lhez aaa3d07ae7 opencl: support sink in soft_max (attn sinks) (#15152) 2025-08-07 21:47:03 -07:00
Xuan-Son Nguyen 50aa938901 convert : support non-mxfp4 HF model (#15153)
* convert : support non-mxfp4 HF model

* rm redundant check

* disable debug check
2025-08-07 23:26:03 +02:00
Jeff Bolz c4f53563df vulkan: support fattn sinks (#15126) 2025-08-07 22:44:20 +02:00
Jeff Bolz a0552c8bee vulkan: Add env var to disable host visible vidmem (#15109) 2025-08-07 22:07:11 +02:00
RunningLeon 99acbc9921 llama : Support intern-s1 (#14875)
* support internvl

* support interns1

* resolve comments

* put interns1 in tensor mapping

* resolve comment

* move tokenizer changes to sub class
2025-08-07 18:20:40 +02:00
uvos 7ad67ba9fe HIP: add cmake option to enable compiler output of kernel resource usage metrics (#15103) 2025-08-07 16:44:14 +02:00
Christian Kastner 9a96389544 ggml: Skip backend library linking code when GGML_BACKEND_DL=ON (#15094)
Any available libraries are found and loaded dynamically at runtime.
2025-08-07 13:45:41 +02:00
Johannes Gäßler 1d72c84188 CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (#15131)
* CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16
2025-08-07 10:53:21 +02:00
Johannes Gäßler 20638e4f16 scripts: fix crash when --tool is not set (#15133) 2025-08-07 08:50:30 +02:00
Daniel Bevenius 36d3f00e14 requirements : fix PyTorch uint64 compatibility (#15134)
This commit addresses an issue with the convert_hf_to_gguf script
which is currently failing with:
```console
AttributeError: module 'torch' has no attribute 'uint64'
```

This occurred because safetensors expects torch.uint64 to be available
in the public API, but PyTorch 2.2.x only provides limited support for
unsigned types beyond uint8 it seems. The torch.uint64 dtype exists but
is not exposed in the standard torch namespace
(see pytorch/pytorch#58734).

PyTorch 2.4.0 properly exposes torch.uint64 in the public API, resolving
the compatibility issue with safetensors. This also required torchvision
to updated to =0.19.0 for compatibility.

Refs: https://huggingface.co/spaces/ggml-org/gguf-my-repo/discussions/186#68938de803e47d990aa087fb
Refs: https://github.com/pytorch/pytorch/issues/58734
2025-08-07 05:31:48 +02:00
Reese Levine 5fd160bbd9 ggml: Add basic SET_ROWS support in WebGPU (#15137)
* Begin work on set_rows

* Work on set rows

* Add error buffers for reporting unsupported SET_ROWS indices

* Remove extra comments
2025-08-06 15:14:40 -07:00
rmatif 756cfea826 fix profiling crash (#15072) 2025-08-06 14:17:51 -07:00
lhez e725a1a982 opencl: add swiglu_oai and add_id (#15121)
* opencl: add `swiglu-oai`

* opencl: add `add_id`

* opencl: add missing `add_id.cl`
2025-08-06 12:12:17 -07:00
Sachin Desai 3db4da56a5 chat : support Granite model reasoning and tool call (#14864) 2025-08-06 20:27:30 +02:00
Juk Armstrong 476aa3fd57 Fixed name -override-tensors to -override-tensor (#15129) 2025-08-06 17:28:48 +01:00
Diego Devesa 0d8831543c ggml : fix fallback to CPU for ununsupported ops (#15118) 2025-08-06 14:37:35 +02:00
Sigbjørn Skjæret 65c797c4fa chat : fix yandex chat template (#15116) 2025-08-06 13:26:49 +02:00
stevenkuang 25726898e8 chat : fix hunyuan auto-detection (#15114)
Signed-off-by: stevenkuang <stevenkuang@tencent.com>
2025-08-06 11:48:30 +02:00
Chenguang Li 2241453252 CANN: add support for ACL Graph (#15065)
* feat(cann): add optional support for ACL Graph execution

This commit adds support for executing ggml computational graphs using
Huawei's ACL graph mode via the USE_CANN_GRAPH flag. The support can be
enabled at compile time using the CMake option:

    -DUSE_CANN_GRAPH=ON

By default, ACL graph execution is **disabled**, and the fallback path
uses node-by-node execution.

Key additions:
- CMake option  to toggle graph mode
- Graph capture and execution logic using
- Tensor property matching to determine whether graph update is required
- Safe fallback and logging if the environment variable LLAMA_SET_ROWS
  is unset or invalid

This prepares the backend for performance improvements in repetitive graph
execution scenarios on Ascend devices.

Signed-off-by: noemotiovon <757486878@qq.com>

* Fix review comments

Signed-off-by: noemotiovon <757486878@qq.com>

* remane USE_CANN_GRAPH to USE_ACL_GRAPH

Signed-off-by: noemotiovon <757486878@qq.com>

* fix typo

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
2025-08-06 14:12:42 +08:00
Reese Levine 9515c6131a ggml: WebGPU disable SET_ROWS for now (#15078)
* Add paramater buffer pool, batching of submissions, refactor command building/submission

* Add header for linux builds

* Free staged parameter buffers at once

* Format with clang-format

* Fix thread-safe implementation

* Use device implicit synchronization

* Update workflow to use custom release

* Remove testing branch workflow

* Disable set_rows until it's implemented

* Fix potential issue around empty queue submission

* Try synchronous submission

* Try waiting on all futures explicitly

* Add debug

* Add more debug messages

* Work on getting ssh access for debugging

* Debug on failure

* Disable other tests

* Remove extra if

* Try more locking

* maybe passes?

* test

* Some cleanups

* Restore build file

* Remove extra testing branch ci
2025-08-05 16:26:38 -07:00
Georgi Gerganov fd1234cb46 llama : add gpt-oss (#15091)
* oai moe

* compat with new checkpoint

* add attn sink impl

* add rope scaling yarn

* logits match with latest transformers code

* wip chat template

* rm trailing space

* use ggml_scale_bias

* rm redundant is_swa_all

* convert interleaved gate_up

* graph : fix activation function to match reference (#7)

* vocab : handle o200k_harmony special tokens

* ggml : add attention sinks support (#1)

* llama : add attn sinks

* ggml : add attn sinks

* cuda : add attn sinks

* vulkan : add support for sinks in softmax

remove unnecessary return

* ggml : add fused swiglu_oai op (#11)

* ggml : add fused swiglu_oai op

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* update CUDA impl

* cont : metal impl

* add vulkan impl

* test-backend-ops : more test cases, clean up

* llama : remove unfused impl

* remove extra lines

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>

* repack mxfp4 upon conversion

* clean up a bit

* enable thinking

* add quick hack to render only some special tokens

* fix bf16 conversion

* remove vocab hack

* webui ok

* support chat parsing for gpt-oss

* fix webui

* direct mapping mxfp4, FINALLY

* force using mxfp4

* properly use lazy tensor

* ggml : add mxfp4

ggml : use e8m0 conversion instead of powf

Co-authored-by: Diego Devesa <slarengh@gmail.com>

change kvalues_mxfp4 table to match e2m1 (#6)

metal : remove quantization for now (not used)

cuda : fix disabled CUDA graphs due to ffn moe bias

vulkan : add support for mxfp4

cont : add cm2 dequant

* ggml : add ggml_add_id (#13)

* ggml : add ggml_add_id

* add cuda impl

* llama : add weight support check for add_id

* perf opt

* add vulkan impl

* rename cuda files

* add metal impl

* allow in-place ggml_add_id

* llama : keep biases on CPU with --cpu-moe

* llama : fix compile error

ggml-ci

* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw

ggml-ci

* cleanup

ggml-ci

* sycl : fix supports_op for MXFP4

ggml-ci

* fix Unknown reasoning format

* ggml-cpu : fix AVX build

ggml-ci

* fix hip build

ggml-ci

* cuda : add mxfp4 dequantization support for cuBLAS

ggml-ci

* ggml-cpu : fix mxfp4 fallback definitions for some architectures

ggml-ci

* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: slaren <slarengh@gmail.com>
2025-08-05 22:10:36 +03:00
Sigbjørn Skjæret f324a3b715 chat : only remove double bos/eos if added (#15086)
* only remove double bos/eos if added

* fix tests
2025-08-05 20:43:36 +02:00
Georgi Gerganov be42642581 readme : update hot topics (#15097) 2025-08-05 20:19:33 +03:00
Romain Biessy 3306ceabf0 sycl: fix mul_mat selection (#15092) 2025-08-05 18:39:55 +02:00
Juk Armstrong c81de6e107 Fix glm4moe bug (#15088) 2025-08-05 13:56:44 +01:00
Alex Wu 22f060c9c4 webui: fix markdown table (#15081)
* webui: fix markdown table

* webui: fix table display with themes
2025-08-05 13:56:44 +02:00
compilade ee3a9fcf88 context : fix index overflow on huge outputs (#15080)
* context : fix overflow when re-ordering huge outputs

* context : fix logits size overflow for huge batches
2025-08-05 11:27:45 +02:00
149 changed files with 6677 additions and 1617 deletions
+2 -1
View File
@@ -17,6 +17,7 @@ LLM inference in C/C++
## Hot topics
- Support for the `gpt-oss` model with native MXFP4 format has been added | [PR](https://github.com/ggml-org/llama.cpp/pull/15091) | [Collaboration with NVIDIA](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) | [Comment](https://github.com/ggml-org/llama.cpp/discussions/15095)
- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
@@ -239,7 +240,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
<details>
<summary>Infrastructure</summary>
- [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp
- [Paddler](https://github.com/intentee/paddler) - Open-source LLMOps platform for hosting and scaling AI in your own infrastructure
- [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs
- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly
- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server
+2 -5
View File
@@ -2947,12 +2947,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
"- none: leaves thoughts unparsed in `message.content`\n"
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
"(default: deepseek)",
"(default: auto)",
[](common_params & params, const std::string & value) {
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
else { throw std::invalid_argument("invalid value"); }
params.reasoning_format = common_reasoning_format_from_name(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
add_opt(common_arg(
+9 -1
View File
@@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
std::string arguments = "";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else {
arguments = tool_call.at("arguments");
}
}
return add_tool_call(name, id, arguments);
}
+201 -2
View File
@@ -126,6 +126,8 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool add_bos;
bool add_eos;
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
@@ -143,6 +145,8 @@ struct templates_params {
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
json extra_context;
bool add_bos;
bool add_eos;
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -445,6 +449,8 @@ std::string common_chat_format_single(
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
inputs.add_bos = tmpls->add_bos;
inputs.add_eos = tmpls->add_eos;
std::string fmt_past_msg;
if (!past_msg.empty()) {
@@ -469,6 +475,8 @@ std::string common_chat_format_single(
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
inputs.add_bos = tmpls->add_bos;
inputs.add_eos = tmpls->add_eos;
auto add_simple_msg = [&](auto role, auto content) {
common_chat_msg msg;
msg.role = role;
@@ -544,8 +552,21 @@ common_chat_templates_ptr common_chat_templates_init(
default_template_src = CHATML_TEMPLATE_SRC;
}
}
// TODO @ngxson : this is a temporary hack to prevent chat template from throwing an error
// Ref: https://github.com/ggml-org/llama.cpp/pull/15230#issuecomment-3173959633
if (default_template_src.find("<|channel|>") != std::string::npos
// search for the error message and patch it
&& default_template_src.find("in message.content or") != std::string::npos) {
string_replace_all(default_template_src,
"{%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}",
"{%- if false %}");
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
bool add_bos = false;
bool add_eos = false;
if (model) {
const auto * vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
@@ -560,9 +581,13 @@ common_chat_templates_ptr common_chat_templates_init(
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
add_bos = llama_vocab_get_add_bos(vocab);
add_eos = llama_vocab_get_add_eos(vocab);
}
common_chat_templates_ptr tmpls(new common_chat_templates());
tmpls->has_explicit_template = has_explicit_template;
tmpls->add_bos = add_bos;
tmpls->add_eos = add_eos;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
@@ -592,6 +617,8 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -600,13 +627,28 @@ const char * common_chat_format_name(common_chat_format format) {
const char * common_reasoning_format_name(common_reasoning_format format) {
switch (format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
default:
throw std::runtime_error("Unknown reasoning format");
}
}
common_reasoning_format common_reasoning_format_from_name(const std::string & format) {
if (format == "none") {
return COMMON_REASONING_FORMAT_NONE;
} else if (format == "auto") {
return COMMON_REASONING_FORMAT_AUTO;
} else if (format == "deepseek") {
return COMMON_REASONING_FORMAT_DEEPSEEK;
} else if (format == "deepseek-legacy") {
return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
}
throw std::runtime_error("Unknown reasoning format: " + format);
}
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
std::string arguments;
if (builder.is_partial()) {
@@ -748,10 +790,10 @@ static std::string apply(
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
}
return result;
@@ -1289,6 +1331,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
tool_calls_end);
}
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_GPT_OSS;
// TODO: support tool calls in GPT-OSS?
return data;
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
// TODO @ngxson : this won't work with --special enabled, we should fix that
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
}
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
@@ -1698,6 +1760,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Pass thinking context for Granite template
json additional_context = {
{"thinking", inputs.enable_thinking},
};
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_GRANITE;
if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
if (!inputs.tools.is_null()) {
// Granite uses <|tool_call|> followed by JSON list
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
"-args", {
{"type", "object"},
{"properties", {
{"name", {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
})));
});
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
if (data.thinking_forced_open) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
} else {
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
}
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
"<|tool_call|>"
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
"<|tool_call|>",
};
});
} else {
// Handle thinking tags for non-tool responses
if (data.thinking_forced_open && inputs.enable_thinking) {
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
};
}
}
return data;
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags using regex
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
if (auto res = builder.try_find_regex(response_regex)) {
// Extract the content between the tags (capture group 1)
auto content = builder.str(res->groups[1]);
builder.add_content(content);
builder.move_to(res->groups[0].end);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.add_tool_calls(tool_calls_data.json)) {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -1731,6 +1911,8 @@ static common_chat_params common_chat_templates_apply_jinja(
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
params.add_bos = inputs.add_bos;
params.add_eos = inputs.add_eos;
params.extra_context = json::object();
for (auto el : inputs.chat_template_kwargs) {
@@ -1767,11 +1949,21 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_command_r7b(tmpl, params);
}
// Granite (IBM) - detects thinking / tools support
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
return common_chat_params_init_granite(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// GPT-OSS
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_gpt_oss(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -1822,6 +2014,7 @@ static common_chat_params common_chat_templates_apply_legacy(
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
@@ -1923,6 +2116,12 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
+5
View File
@@ -109,6 +109,8 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -127,6 +129,8 @@ struct common_chat_templates_inputs {
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::map<std::string, std::string> chat_template_kwargs;
bool add_bos = false;
bool add_eos = false;
};
struct common_chat_params {
@@ -187,6 +191,7 @@ std::string common_chat_format_example(
const char* common_chat_format_name(common_chat_format format);
const char* common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
+3 -1
View File
@@ -236,8 +236,10 @@ struct common_params_diffusion {
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
struct common_params {
@@ -394,7 +396,7 @@ struct common_params {
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
+487 -101
View File
@@ -28,6 +28,14 @@ if TYPE_CHECKING:
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
from gguf.vocab import MistralTokenizerType, MistralVocab
from mistral_common.tokens.tokenizers.base import TokenizerVersion
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
logger = logging.getLogger("hf-to-gguf")
@@ -81,6 +89,8 @@ class ModelBase:
block_count: int
tensor_map: gguf.TensorNameMap
is_mistral_format: bool = False
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
@@ -106,16 +116,17 @@ class ModelBase:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
self.tensor_names = set(name for name in remote_tensors.keys())
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
for name, remote_tensor in remote_tensors.items():
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
self.get_tensors = get_remote_tensors
else:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
prefix = "model" if not self.is_mistral_format else "consolidated"
self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.tensor_names = None
self.metadata_override = metadata_override
self.model_name = model_name
@@ -153,19 +164,23 @@ class ModelBase:
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_names_from_parts: set[str] = set()
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
index_name += ".index.json"
index_file = self.dir_model / index_name
if not self.is_mistral_format:
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
index_name += ".index.json"
index_file = self.dir_model / index_name
if index_file.is_file():
self.tensor_names = set()
logger.info(f"gguf: loading model weight map from '{index_name}'")
with open(index_file, "r", encoding="utf-8") as f:
index: dict[str, Any] = json.load(f)
weight_map = index.get("weight_map")
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
self.tensor_names.update(weight_map.keys())
if index_file.is_file():
self.tensor_names = set()
logger.info(f"gguf: loading model weight map from '{index_name}'")
with open(index_file, "r", encoding="utf-8") as f:
index: dict[str, Any] = json.load(f)
weight_map = index.get("weight_map")
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
self.tensor_names.update(weight_map.keys())
else:
self.tensor_names = tensor_names_from_parts
weight_map = {}
else:
self.tensor_names = tensor_names_from_parts
weight_map = {}
@@ -426,7 +441,12 @@ class ModelBase:
return part_names
@staticmethod
def load_hparams(dir_model: Path):
def load_hparams(dir_model: Path, is_mistral_format: bool):
if is_mistral_format:
with open(dir_model / "params.json", "r", encoding="utf-8") as f:
config = json.load(f)
return config
try:
# for security reason, we don't allow loading remote code by default
# if a model need remote code, we will fallback to config.json
@@ -476,7 +496,10 @@ class TextModel(ModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hf_arch = get_model_architecture(self.hparams, self.model_type)
if not self.is_mistral_format:
self.hf_arch = get_model_architecture(self.hparams, self.model_type)
else:
self.hf_arch = ""
if "text_config" in self.hparams:
# move the text_config to the root level
@@ -542,14 +565,14 @@ class TextModel(ModelBase):
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
if (n_head_kv := self.find_hparam(["num_key_value_heads", "n_kv_heads"], optional=True)) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}")
if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
logger.info(f"gguf: rope theta = {rope_theta}")
if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
@@ -1210,12 +1233,19 @@ class MmprojModel(ModelBase):
raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
# get n_embd of the text model
if "text_config" not in self.hparams:
self.hparams["text_config"] = {}
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
if not self.is_mistral_format:
if "text_config" not in self.hparams:
self.hparams["text_config"] = {}
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
else:
text_config = {
k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"]
}
self.n_embd_text = text_config.get("hidden_dim", 0)
assert self.n_embd_text > 0, "n_embd not found in hparams"
# move vision config to the top level, while preserving the original hparams in global_config
@@ -1236,11 +1266,13 @@ class MmprojModel(ModelBase):
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
if not self.is_mistral_format:
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
def get_vision_config(self) -> dict[str, Any] | None:
return self.global_config.get("vision_config")
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
return self.global_config.get(config_name)
def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config.get("audio_config")
@@ -1264,8 +1296,11 @@ class MmprojModel(ModelBase):
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
# preprocessor config
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
self.gguf_writer.add_vision_image_mean(image_mean)
self.gguf_writer.add_vision_image_std(image_std)
if self.has_audio_encoder:
self.gguf_writer.add_clip_has_audio_encoder(True)
@@ -1924,11 +1959,63 @@ class LlamaModel(TextModel):
if self.hf_arch == "VLlama3ForCausalLM":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
def _set_vocab_mistral(self):
vocab = MistralVocab(self.dir_model)
logger.info(
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
)
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
assert len(tokens) == vocab.vocab_size, (
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
)
if vocab.tokenizer_type == MistralTokenizerType.tekken:
self.gguf_writer.add_tokenizer_pre("tekken")
self.gguf_writer.add_token_merges(
vocab.extract_vocab_merges_from_model()
)
logger.info(
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
)
self.gguf_writer.add_bos_token_id(vocab.bos_id)
self.gguf_writer.add_eos_token_id(vocab.eos_id)
self.gguf_writer.add_unk_token_id(vocab.unk_id)
self.gguf_writer.add_pad_token_id(vocab.pad_id)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_vocab_size(vocab.vocab_size)
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(False)
template_dir = Path(__file__).parent / "models/templates/"
template = MistralModel.get_community_chat_template(vocab, template_dir)
self.gguf_writer.add_chat_template(template)
def set_vocab(self):
if self.is_mistral_format:
return self._set_vocab_mistral()
path_tekken_json = self.dir_model / "tekken.json"
path_tokenizer_json = self.dir_model / "tokenizer.json"
if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
return self.set_vocab_tekken()
self._set_vocab_mistral()
try:
self._set_vocab_sentencepiece()
@@ -1962,56 +2049,12 @@ class LlamaModel(TextModel):
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)
def set_vocab_tekken(self):
vocab = gguf.vocab.MistralVocab(self.dir_model)
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
assert len(tokens) == vocab.vocab_size, (
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
)
if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
self.gguf_writer.add_tokenizer_pre("tekken")
self.gguf_writer.add_token_merges(
vocab.extract_vocab_merges_from_model()
)
logger.info(
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
)
self.gguf_writer.add_bos_token_id(vocab.bos_id)
self.gguf_writer.add_eos_token_id(vocab.eos_id)
self.gguf_writer.add_unk_token_id(vocab.unk_id)
self.gguf_writer.add_pad_token_id(vocab.pad_id)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_vocab_size(vocab.vocab_size)
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(False)
script_dir = Path(__file__).parent
template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
with open(template_path, "r", encoding="utf-8") as f:
template = f.read()
self.gguf_writer.add_chat_template(template)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if not self.is_mistral_format:
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
@@ -2033,13 +2076,25 @@ class LlamaModel(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
n_head = self.find_hparam(["n_heads", "num_attention_heads"])
n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"])
vision_prefixes = [
"vision_encoder.",
"vision_language_adapter.",
"patch_merger.",
"pre_mm_projector_norm",
]
is_multimodal_tensor = "vision_tower" in name \
or "vision_model" in name \
or "audio_tower" in name \
or "model.connector" in name \
or "multi_modal_projector" in name
or "multi_modal_projector" in name \
or any(
name.startswith(prefix)
for prefix in vision_prefixes
)
if is_multimodal_tensor:
return [] # skip vision tensors
@@ -2155,13 +2210,18 @@ class LlavaVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "pixtral":
if self.hparams.get("model_type") == "pixtral":
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
logger.info(f"Image break token id: {self.img_break_tok_id}")
elif self.is_mistral_format:
# hparams is already vision config here so norm_eps is only defined in global_config.
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
logger.info(f"Image break token id: {self.img_break_tok_id}")
def get_token_id(self, token: str) -> int:
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
@@ -2175,7 +2235,7 @@ class LlavaVisionModel(MmprojModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
if hparams.get("model_type") == "pixtral":
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
@@ -2193,18 +2253,30 @@ class LlavaVisionModel(MmprojModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
n_head = self.hparams["num_attention_heads"]
n_head = (
self.hparams["num_attention_heads"] if not self.is_mistral_format else self.find_vparam(["num_attention_heads"])
)
n_kv_head = n_head
if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
valid_prefixes = (
"multi_modal_projector.",
"vision_tower.",
"vision_encoder.",
"vision_language_adapter.",
"patch_merger.",
"pre_mm_projector_norm",
)
if any(name.startswith(prefix) for prefix in valid_prefixes):
# process vision tensors
if name.endswith(("q_proj.weight", "q_proj.bias")):
if name.endswith(("q_proj.weight", "q_proj.bias")) and not self.is_mistral_format:
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
if name.endswith(("k_proj.weight", "k_proj.bias")) and not self.is_mistral_format:
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)]
if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
embed_key = "embed_tokens.weight" if not self.is_mistral_format else "tok_embeddings.weight"
if self.img_break_tok_id > 0 and embed_key in name:
logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
# for pixtral model, we need to extract the [IMG_BREAK] token embedding
img_break_embd = data_torch[self.img_break_tok_id]
@@ -3328,7 +3400,13 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
@ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel):
def set_gguf_parameters(self):
assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list):
self.hparams_vision['image_size'] = self.hparams_vision['image_size'][0]
if isinstance(self.hparams_vision['patch_size'], list):
self.hparams_vision['patch_size'] = self.hparams_vision['patch_size'][0]
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
@@ -3352,14 +3430,30 @@ class InternVisionModel(MmprojModel):
return gguf.GGMLQuantizationType.F32
return False
def _mapping_interns1_name(self, name):
names_map = {
"model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
"model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
"model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
"model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
"model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
"model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
}
if name in names_map:
name = names_map[name]
return name
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("vision_model") or name.startswith("mlp"):
vision_prefix = ['vision_model', 'mlp', 'model.vision_tower', 'model.multi_modal_projector']
# deal with intern-s1 special case
name = self._mapping_interns1_name(name)
if any([name.startswith(prefix) for prefix in vision_prefix]):
# process visual tensors
# correct name
if name.startswith("vision_model"):
name = "vision_tower." + name
if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
if (".ls" in name or ".lambda_" in name or "position_embedding" in name) and not name.endswith(".weight"):
name += ".weight"
# split QKV tensors if needed
if ".qkv." in name:
@@ -3445,6 +3539,10 @@ class Qwen2MoeModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip visual tensors
return []
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
@@ -3498,6 +3596,85 @@ class Qwen3Model(Qwen2Model):
class Qwen3MoeModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3MOE
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hparams = ModelBase.load_hparams(self.dir_model, False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
def set_vocab(self):
# deal with intern-s1
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
self._set_vocab_interns1()
return
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
additional_special_tokens = []
if special_tokens_map_file.is_file():
with open(special_tokens_map_file, encoding = 'utf-8') as f:
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
if tokenizer_cfg_file.is_file():
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
for token in additional_special_tokens:
if token in token2ids_map:
special_vocab._set_special_token(token, token2ids_map[token])
special_vocab._set_special_token('eos', 151645)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)
@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
@@ -4578,7 +4755,7 @@ class NomicBertModel(BertModel):
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
hparams = ModelBase.load_hparams(dir_model)
hparams = ModelBase.load_hparams(dir_model, False)
self.is_moe = bool(hparams.get("moe_every_n_layers"))
self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
@@ -7950,6 +8127,130 @@ class SmolLM3Model(LlamaModel):
self.gguf_writer.add_chat_template(chat_template)
@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPT_OSS
def transform_nibble_layout(self, tensor):
assert tensor.dtype == torch.uint8
assert tensor.shape[-1] == 16
# swap nibbles
t_lo = tensor & 0x0F
t_hi = tensor & 0xF0
t_swapped = (t_lo << 4) | (t_hi >> 4)
tensor = t_swapped
# transform aaaa...bbbb... to abababab...
blk_a, blk_b = tensor.chunk(2, dim=-1)
# get a_
blk_a0 = (blk_a & 0xF0).view(-1, 1)
blk_a1 = (blk_a << 4).view(-1, 1)
blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
# get _b
blk_b0 = (blk_b >> 4).view(-1, 1)
blk_b1 = (blk_b & 0x0F).view(-1, 1)
blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
# swap once more
out = blk_a | blk_b
out_h = out & 0xF0
out_l = out & 0x0F
out = (out_h >> 4) | (out_l << 4)
return out
def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
assert blocks.dtype == torch.uint8
assert scales.dtype == torch.uint8
scales = scales.unsqueeze(-1)
assert len(blocks.shape) == 4
assert len(scales.shape) == 4
blocks = self.transform_nibble_layout(blocks)
new_data = torch.concat((scales, blocks), dim=-1)
new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
# flatten last dim
new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
new_data = new_data.numpy()
self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
blocks0: Tensor = torch.zeros(1)
blocks1: Tensor = torch.zeros(1)
# we assume that tensors are loaded in the correct order
for name, data_torch in self.get_tensors():
if "mlp.experts.down_proj_blocks" in name:
blocks0 = data_torch
elif "mlp.experts.down_proj_scales" in name:
new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
self.repack_mxfp4(new_name, blocks0, data_torch)
elif "mlp.experts.gate_up_proj_blocks" in name:
blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
elif "mlp.experts.gate_up_proj_scales" in name:
scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
self.repack_mxfp4(new_name_gate, blocks0, scales0)
self.repack_mxfp4(new_name_up, blocks1, scales1)
return []
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if "sinks" in name:
name += ".weight"
# correct naming for down_proj
if "down_proj" in name:
if name.endswith("_bias"):
name = name.replace("down_proj_bias", "down_proj.bias")
elif "_blocks" not in name and "_scales" not in name:
logger.warning(f"{name} is not in MXFP4, performance may be degraded")
name = name.replace("down_proj", "down_proj.weight")
data_torch = data_torch.transpose(-1, -2)
else:
# otherwise, it should already be repacked to ggml MXFP4 format
return []
# split the gate_up into gate and up
if "gate_up_proj" in name:
if name.endswith("_bias"):
name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
return [
(self.map_tensor_name(name_gate), gate_proj_bias),
(self.map_tensor_name(name_up), up_proj_bias)
]
elif "_blocks" not in name and "_scales" not in name:
logger.warning(f"{name} is not in MXFP4, performance may be degraded")
name_up = name.replace("gate_up_proj", "up_proj.weight")
name_gate = name.replace("gate_up_proj", "gate_proj.weight")
data_torch = data_torch.transpose(-1, -2)
gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :]
return [
(self.map_tensor_name(name_gate), gate_proj_weight),
(self.map_tensor_name(name_up), up_proj_weight)
]
else:
# otherwise, it should already be repacked to ggml MXFP4 format
return []
return [(self.map_tensor_name(name), data_torch)]
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])
rope_scaling = self.hparams.get("rope_scaling") or {}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
@@ -8075,6 +8376,77 @@ class SmallThinkerModel(TextModel):
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
class MistralModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA
model_name = "Mistral"
hf_arch = ""
is_mistral_format = True
undo_permute = False
@staticmethod
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path):
assert TokenizerVersion is not None, "mistral_common is not installed"
assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
)
if vocab.tokenizer.version == TokenizerVersion.v1:
return "mistral-v1"
elif vocab.tokenizer.version == TokenizerVersion.v3 and vocab.tokenizer_type == MistralTokenizerType.spm:
return "mistral-v3"
elif vocab.tokenizer.version == TokenizerVersion.v3 and vocab.tokenizer_type == MistralTokenizerType.tekken:
return "mistral-v3-tekken"
elif vocab.tokenizer.version == TokenizerVersion.v7 and vocab.tokenizer_type == MistralTokenizerType.spm:
return "mistral-v7"
elif vocab.tokenizer.version == TokenizerVersion.v7 and vocab.tokenizer_type == MistralTokenizerType.tekken:
return "mistral-v7-tekken"
elif vocab.tokenizer.version == TokenizerVersion.v11:
template_file = "Mistral-Small-3.2-24B-Instruct-2506.jinja"
elif vocab.tokenizer.version == TokenizerVersion.v13:
template_file = "unsloth-mistral-Devstral-Small-2507.jinja"
else:
raise ValueError(f"Unknown tokenizer type: {vocab.tokenizer_type} and version {vocab.tokenizer.version}")
template_path = templates_dir / template_file
if not template_path.exists():
raise FileNotFoundError(f"Template file not found: {template_path}")
with open(template_path, "r", encoding="utf-8") as f:
template = f.read()
return template
class PixtralModel(LlavaVisionModel):
model_name = "Pixtral"
hf_arch = ""
is_mistral_format = True
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(
self.find_hparam(["norm_eps"])
)
self.gguf_writer.add_rope_freq_base(self.find_vparam(["rope_theta"]))
self.gguf_writer.add_vision_use_silu(True)
# spatial_merge_size
if self.find_vparam(["mm_projector_id"]) == "patch_merge":
self.gguf_writer.add_vision_spatial_merge_size(
self.find_vparam(["spatial_merge_size"])
)
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
if name == "vision_language_adapter.w_in.weight":
return "mm.1.weight"
elif name == "vision_language_adapter.w_out.weight":
return "mm.2.weight"
return super().map_tensor_name(name, try_suffixes)
###### CONVERSION LOGIC ######
@@ -8089,6 +8461,7 @@ class LazyTorchTensor(gguf.LazyBase):
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.uint8: np.uint8,
}
# used for safetensors slices
@@ -8224,6 +8597,10 @@ def parse_args() -> argparse.Namespace:
"--mmproj", action="store_true",
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
)
parser.add_argument(
"--mistral-format", action="store_true",
help="Whether the model is stored following the Mistral format.",
)
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
@@ -8329,17 +8706,25 @@ def main() -> None:
if "mmproj" not in fname_out.name:
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
is_mistral_format = args.mistral_format
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
hparams = ModelBase.load_hparams(dir_model)
model_architecture = get_model_architecture(hparams, model_type)
logger.info(f"Model architecture: {model_architecture}")
try:
model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
except NotImplementedError:
logger.error(f"Model {model_architecture} is not supported")
sys.exit(1)
hparams = ModelBase.load_hparams(dir_model, is_mistral_format)
if not is_mistral_format:
model_architecture = get_model_architecture(hparams, model_type)
logger.info(f"Model architecture: {model_architecture}")
try:
model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
except NotImplementedError:
logger.error(f"Model {model_architecture} is not supported")
sys.exit(1)
elif args.mmproj:
assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
model_class = PixtralModel
else:
model_class = MistralModel
model_instance = model_class(dir_model, output_type, fname_out,
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
@@ -8348,7 +8733,8 @@ def main() -> None:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id)
remote_hf_model_id=hf_repo_id,
)
if args.vocab_only:
logger.info("Exporting model vocab...")
+1 -1
View File
@@ -340,7 +340,7 @@ if __name__ == '__main__':
sys.exit(1)
else:
logger.info(f"Loading base model: {dir_base_model.name}")
hparams = ModelBase.load_hparams(dir_base_model)
hparams = ModelBase.load_hparams(dir_base_model, False)
with torch.inference_mode():
try:
+1 -1
View File
@@ -13,7 +13,7 @@ If there are differences in usage, please refer to the official build [documenta
Clone llama.cpp:
```bash
git clone https://github.com/ggerganov/llama.cpp
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
```
+1 -1
View File
@@ -12,7 +12,7 @@ If there are differences in usage, please refer to the official build [documenta
Clone llama.cpp:
```bash
git clone https://github.com/ggerganov/llama.cpp
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
```
+1
View File
@@ -176,6 +176,7 @@ option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM"
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
+42 -40
View File
@@ -106,7 +106,7 @@ if(NOT TARGET ggml::ggml)
find_library(GGML_LIBRARY ggml
REQUIRED
HINTS ${GGML_LIB_DIR} ${GGML_BACKEND_DIR}
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
add_library(ggml::ggml UNKNOWN IMPORTED)
@@ -125,54 +125,56 @@ if(NOT TARGET ggml::ggml)
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
set(_ggml_all_targets "")
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
if (NOT GGML_BACKEND_DL)
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
endif()
else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
endif()
else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
endif()
endif()
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
endif()
endif()
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
endif()
list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
set_target_properties(ggml::ggml
+37 -1
View File
@@ -304,6 +304,16 @@
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_TERNARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
@@ -395,7 +405,8 @@ extern "C" {
// GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_COUNT = 39,
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_COUNT = 40,
};
// precision
@@ -430,6 +441,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
};
// available tensor operations:
@@ -438,6 +450,7 @@ extern "C" {
GGML_OP_DUP,
GGML_OP_ADD,
GGML_OP_ADD_ID,
GGML_OP_ADD1,
GGML_OP_ACC,
GGML_OP_SUB,
@@ -557,6 +570,7 @@ extern "C" {
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
GGML_GLU_OP_SWIGLU_OAI,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,
@@ -831,6 +845,13 @@ extern "C" {
struct ggml_tensor * b,
enum ggml_type type);
// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
GGML_API struct ggml_tensor * ggml_add_id(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * ids);
GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1198,6 +1219,13 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_swiglu_oai(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float alpha,
float limit);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
@@ -1570,6 +1598,10 @@ extern "C" {
float scale,
float max_bias);
GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -2052,6 +2084,10 @@ extern "C" {
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
const struct ggml_tensor * a);
GGML_API void ggml_flash_attn_ext_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
+1
View File
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
+7 -2
View File
@@ -1071,6 +1071,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
}
}
}
// if the node is still unassigned, assign it to the first backend that supports it
for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) {
ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id);
}
GGML_ASSERT(*cur_backend_id != -1);
}
// pass 5: split graph, find tensors that need to be copied
@@ -1098,7 +1103,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
const int node_backend_id = tensor_backend_id(node);
assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
// check if we should start a new split based on the sources of the current node
bool need_new_split = false;
@@ -1156,7 +1161,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
size_t src_id = hash_id(src);
const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
assert(src_backend_id != -1); // all inputs should be assigned by now
GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now
if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
+4 -4
View File
@@ -281,10 +281,10 @@ ggml_backend_t ggml_backend_blas_init(void) {
ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_blas_guid(),
/* .interface = */ blas_backend_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
/* .context = */ ctx,
/* .guid = */ ggml_backend_blas_guid(),
/* .iface = */ blas_backend_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
/* .context = */ ctx,
};
#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+14
View File
@@ -31,6 +31,13 @@ string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}")
set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}")
option(USE_ACL_GRAPH "Enable CANN graph execution (ACL graph mode)" OFF)
if(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL "310P" OR SOC_TYPE_COMPILE_OPTION STREQUAL "ASCEND_310P"))
message(FATAL_ERROR
"CANN Graph (ACL graph mode) is not supported on 310P devices. "
"Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.")
endif()
if (CANN_INSTALL_DIR)
# Only Support Linux.
@@ -68,6 +75,13 @@ if (CANN_INSTALL_DIR)
target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
if (USE_ACL_GRAPH)
target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH)
message(STATUS "CANN: USE_ACL_GRAPH is enabled.")
else()
message(STATUS "CANN: USE_ACL_GRAPH is disabled.")
endif()
message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
else()
+245 -380
View File
@@ -753,69 +753,55 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0];
aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
if (ggml_are_same_shape(src0, dst)) {
aclTensor* acl_src = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
if (dst->type == src0->type) {
cann_copy(ctx, acl_src, acl_dst);
} else {
aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
}
ggml_cann_release_resources(ctx, acl_src, acl_dst);
} else {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
if (dst->type == src0->type) {
size_t cpy_size = ggml_nbytes(dst);
ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
ACL_MEMCPY_DEVICE_TO_DEVICE);
return;
} else {
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(),
ggml_nelements(dst) * ggml_type_size(dst->type));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), src0->ne, src_trans_nb,
GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
size_t cpy_size = ggml_nbytes(dst);
ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
ACL_MEMCPY_DEVICE_TO_DEVICE);
ggml_cann_release_resources(ctx, src_trans_tensor);
return;
}
} else if (ggml_is_contiguous(dst)) {
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
void* src_trans_buffer = src_buffer_allocator.get();
void* src_trans_buffer = src0->data;
ggml_cann_pool_alloc src_buffer_allocator;
if (!ggml_is_contiguous(src0)) {
aclTensor* acl_src = ggml_cann_create_tensor(src0);
src_buffer_allocator.alloc(ctx.pool(),
ggml_nelements(src0) * ggml_type_size(src0->type));
src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = ggml_type_size(dst->type);
src_trans_nb[0] = ggml_type_size(src0->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), src0->ne, src_trans_nb,
src_trans_buffer, ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src_trans_nb,
GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
size_t cpy_size = ggml_nbytes(dst);
ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
ACL_MEMCPY_DEVICE_TO_DEVICE);
ggml_cann_release_resources(ctx, src_trans_tensor);
return;
} else {
GGML_ABORT("Unsupport dst is not tontiguous.");
cann_copy(ctx, acl_src, src_trans_tensor);
ggml_cann_release_resources(ctx, acl_src, src_trans_tensor);
}
size_t src_reshape_nb[GGML_MAX_DIMS];
src_reshape_nb[0] = ggml_type_size(src0->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1];
}
aclTensor* trans_acl_src = ggml_cann_create_tensor(src_trans_buffer,
ggml_cann_type_mapping(src0->type),ggml_type_size(src0->type),
dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
if (dst->type == src0->type) {
cann_copy(ctx, trans_acl_src, acl_dst);
} else {
aclnn_cast(ctx, trans_acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
}
ggml_cann_release_resources(ctx, trans_acl_src, acl_dst);
}
ggml_cann_release_resources(ctx, acl_src, acl_dst);
return;
}
/**
@@ -1330,160 +1316,196 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
}
/**
* @brief Applies the Alibi (Attention with Linear Biases) mechanism to the
* @details This function implements the Alibi mechanism, which introduces
* learnable biases into the attention scores to simulate relative
* position encoding without the need for explicit positional
* embeddings.
* @brief Generate a range of values and apply a scalar base exponentiation.
*
* @param ctx The backend CANN context for executing operations.
* @param acl_src The source tensor representing the query or key.
* @param acl_position The position tensor containing relative positions.
* @param acl_dst The destination tensor where the result will be stored.
* @param n_head The number of attention heads.
* @param src_ne The dimensions of the source tensor.
* @param src_nb0 The byte size of the first dimension of the source
tensor.
* @param max_bias The maximum bias value used in the Alibi mechanism.
* @param dst The destination tensor object for additional metadata.
* This function creates an evenly spaced sequence from `start` to `stop` (exclusive),
* with step size `step`, stores it in a temporary buffer, and then computes:
*
* The function performs the following steps:
* 1. Calculates the logarithm floor of the number of heads to determine the
base for bias calculation.
* 2. Initializes arrays with arithmetic sequences and fills them with bias
values.
* 3. Computes the bias tensor based on the calculated biases and arithmetic
sequences.
* 4. Reshapes the bias tensor to match the dimensions of the input tensors.
* 5. Multiplies the position tensor by the bias tensor.
* 6. Adds the result of the multiplication to the source tensor to produce the
final output.
* @f[
* slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size
* @f]
*
* The results are written to the provided @p slope_buffer.
*
* @param ctx CANN backend context for memory allocation and operator execution.
* @param slope_buffer Pointer to the output buffer (float array) for the computed slope values.
* @param m Scalar base for the exponentiation.
* @param size Number of elements in the generated sequence.
* @param start Starting exponent offset.
* @param stop Stopping exponent offset (exclusive).
* @param step Step size for the exponent increment.
*/
static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_position, aclTensor* acl_dst,
const int n_head, int64_t* src_ne, const size_t src_nb0,
float max_bias, ggml_tensor* dst) {
const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
GGML_ASSERT(src_nb0 == sizeof(float));
GGML_ASSERT(n_head == src_ne[2]);
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){
int64_t ne[] = {size};
size_t nb[] = {sizeof(float)};
const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
void* arange_buffer = arange_allocator.get();
float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
aclTensor* arange_tensor = ggml_cann_create_tensor(
arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);
// init arange
ggml_cann_pool_alloc arange_allocator(ctx.pool(),
ne2_ne3 * ggml_type_size(dst->type));
void* tmp_arange_buffer = arange_allocator.get();
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
// arange1: [1, ..., n_heads_log2_floor+1)
float start = 1;
float stop = n_heads_log2_floor + 1;
float step = 1;
int64_t n_elements_arange = n_heads_log2_floor;
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
size_t tmp_arange1_nb[] = {sizeof(dst->type)};
aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
aclTensor* tmp_arange2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
// arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
start = 1;
stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
step = 2;
n_elements_arange = ne2_ne3 - n_heads_log2_floor;
int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_arange2_nb[] = {sizeof(dst->type)};
aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
(char*)tmp_arange_buffer +
n_heads_log2_floor * ggml_type_size(dst->type),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
n_elements_arange);
}
// init mk_base
ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
ne2_ne3 * ggml_type_size(dst->type));
void* tmp_mk_base_buffer = mk_base_allocator.get();
int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
(char*)tmp_mk_base_buffer +
n_heads_log2_floor * ggml_type_size(dst->type),
ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
}
// init mk
int64_t tmp_mk_base_ne[] = {ne2_ne3};
size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
}
aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
// acl_position * mk
int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
size_t tmp_output_nb[GGML_MAX_DIMS];
tmp_output_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
}
ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
void* tmp_output_buffer = output_allocator.get();
aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
tmp_output_buffer, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
// add
aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor);
ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor);
}
void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
/**
* @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.
*
* This function generates slope values for each attention head according to the ALiBi
* (Attention with Linear Biases) method. It splits the computation into two ranges depending
* on whether the head index is less than @p n_head_log2 or not, and uses different base values
* (`m0` and `m1`) for the exponentiation.
*
* @f[
* slope[h] =
* \begin{cases}
* m_0^{(h + 1)}, & h < n\_head\_log2 \\
* m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2
* \end{cases}
* \quad , \quad \text{if } max\_bias > 0
* @f]
*
* If @p max_bias <= 0, all slope values are set to 1.0.
*
* @param ctx CANN backend context for memory allocation and operator execution.
* @param n_head Total number of attention heads.
* @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
* @param max_bias Maximum bias value for slope computation.
*
*/
static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
void* slope_buffer, float max_bias) {
const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// const float slope = (max_bias > 0.0f) ?
// h < n_head_log2 ?
// powf(m0, h + 1) :
// powf(m1, 2*(h - n_head_log2) + 1) :
// 1.0f;
// arange1
float start = 0 + 1;
float end = (n_head_log2 - 1) + 1;
float step = 1;
float count = n_head_log2;
// end needs to be +1 because aclnn uses a left-closed, right-open interval.
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
if (n_head_log2 < n_head) {
// arange2
start = 2 * (n_head_log2 - n_head_log2) + 1;
end = 2 * ((n_head - 1) - n_head_log2) + 1;
step = 2;
count = n_head - n_head_log2;
aclnn_get_slope_inner(
ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
m1, count, start, end + 1, step);
}
}
/**
* @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.
*
* This function computes the ALiBi slopes for each attention head (if max_bias > 0),
* multiplies them with the attention mask to produce bias tensors, and adds these biases
* to the destination tensor (@p dst).
*
* The function performs necessary broadcasting of the mask and slope tensors to match
* the shape of the destination tensor, then applies element-wise multiplication and addition
* using CANN operators.
*
* @param ctx CANN backend context for memory management and operator execution.
* @param mask Input attention mask tensor, assumed to be contiguous.
* @param dst Destination tensor to which ALiBi biases will be added.
* @param dst_ptr Pointer to the memory of the destination tensor.
* @param max_bias Maximum bias value controlling the slope scaling.
*
* @note
* - Write data into dst_ptr using only the shape information of the dst tensor.
* - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.
*/
static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
ggml_tensor* dst, void* dst_ptr, float max_bias) {
void* slope_buffer = nullptr;
void* bias_buffer = nullptr;
if (max_bias > 0.0f) {
int64_t n_heads = dst->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
slope_buffer = slope_allocator.get();
ggml_cann_pool_alloc bias_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
bias_buffer = bias_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
}
// broadcast for mask, slop and dst;
int64_t nr2 = dst->ne[2] / mask->ne[2];
int64_t nr3 = dst->ne[3] / mask->ne[3];
// broadcast the mask across rows
int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };
size_t mask_nb[] = {
mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2],
mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3]
};
int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };
size_t dst_nb[] = {
dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2],
dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3]
};
// slope is a 1 dim tensor, slope.ne2 == dst.ne2
int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };
size_t slope_nb[GGML_MAX_DIMS + 2];
slope_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
}
aclTensor* acl_slope = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float),
slope_ne, slope_nb, GGML_MAX_DIMS + 2);
aclTensor* acl_mask = ggml_cann_create_tensor(
mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);
// write data into dst_ptr using only the shape information of the dst tensor.
aclTensor* acl_dst = ggml_cann_create_tensor(
dst_ptr, ggml_cann_type_mapping(dst->type),
ggml_type_size(dst->type), dst_ne, dst_nb,
GGML_MAX_DIMS + 2);
if (max_bias > 0.0f) {
int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };
size_t bias_nb[GGML_MAX_DIMS + 2];
bias_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];
}
aclTensor* bias_tensor = ggml_cann_create_tensor(
bias_buffer, ACL_FLOAT, sizeof(float),
bias_ne, bias_nb, GGML_MAX_DIMS + 2);
aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor);
aclnn_add(ctx, acl_dst, bias_tensor);
ggml_cann_release_resources(ctx, bias_tensor);
} else {
aclnn_add(ctx, acl_dst, acl_mask);
}
ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst);
}
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_dup(ctx, dst);
}
@@ -1501,118 +1523,41 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
* @param acl_dst The destination tensor where the softmax results will be
* stored.
*/
static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
int64_t dim, aclTensor* acl_dst) {
static void aclnn_softmax(ggml_backend_cann_context & ctx,
aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
}
void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1]; // mask
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
float scale = 1.0f;
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// input mul scale
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));
void* src_tensor_buffer = src_tensor_allocator.get();
aclTensor* softmax_tensor = ggml_cann_create_tensor(
src_tensor_buffer, ggml_cann_type_mapping(src0->type),
ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS);
size_t n_bytes = ggml_nbytes(src0);
ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
void* input_mul_scale_buffer = mul_scale_allocator.get();
aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
src0->nb, GGML_MAX_DIMS);
bool inplace = false;
aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false);
// mask
aclTensor* acl_src1_fp32_tensor = nullptr;
aclTensor* tmp_mask_tensor = nullptr;
ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
if (src1) {
const bool use_f16 = src1->type == GGML_TYPE_F16;
if (use_f16) {
// cast to fp32
size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
size_t src1_fp32_nb[GGML_MAX_DIMS];
src1_fp32_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
}
src1_fp32_allocator.alloc(n_bytes);
void* src1_fp32_buffer = src1_fp32_allocator.get();
acl_src1_fp32_tensor = ggml_cann_create_tensor(
src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
src1_fp32_nb, GGML_MAX_DIMS);
aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
ggml_cann_release_resources(ctx, acl_src1);
} else {
acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
}
// broadcast the mask across rows, only use ne11 of ne01 in mask
if (src1->ne[1] != src0->ne[1]) {
// mask shape: [1,1,ne11,ne10]
int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
size_t tmp_mask_nb[GGML_MAX_DIMS];
tmp_mask_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
}
tmp_mask_tensor = ggml_cann_create_tensor(
src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
}
// alibi
const int n_head = src0->ne[2];
const size_t src_nb0 = src0->nb[0];
n_bytes = ggml_nbytes(dst);
ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
void* output_buffer = output_allocator.get();
aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
dst->nb, GGML_MAX_DIMS);
if (max_bias <= 0.0f) {
// slope = 1.0
if (tmp_mask_tensor) {
aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor);
} else {
aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor);
}
} else {
// slope != 1.0
if (tmp_mask_tensor) {
aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
alibi_output_tensor, n_head, src0->ne, src_nb0,
max_bias, dst);
} else {
aclnn_alibi(ctx, acl_input_mul_scale_tensor,
acl_src1_fp32_tensor, alibi_output_tensor, n_head,
src0->ne, src_nb0, max_bias, dst);
}
}
// softmax
aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
ggml_cann_release_resources(ctx, alibi_output_tensor);
} else {
aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);
}
ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
// softmax
aclnn_softmax(ctx, softmax_tensor, 3, acl_dst);
ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor);
}
/**
@@ -3208,104 +3153,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
// Compute the slope if needed. Derived from ggml_cann_softmax().
if(maxBias != 0.0f){
// alibi
const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
const int64_t n_head = src0->ne[2];
const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
// init arange
ggml_cann_pool_alloc arange_allocator(ctx.pool(),
ne2_ne3 * faElemSize);
void* tmp_arange_buffer = arange_allocator.get();
const int64_t n_heads = src0->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
// arange1: [1, ..., n_heads_log2_floor+1)
float start = 1;
float stop = n_heads_log2_floor + 1;
float step = 1;
int64_t n_elements_arange = n_heads_log2_floor;
int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
size_t tmp_arange1_nb[] = {faElemSize};
aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, faDataType, faElemSize,
tmp_arange1_ne, tmp_arange1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
aclTensor* tmp_arange2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
// arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
start = 1;
stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
step = 2;
n_elements_arange = ne2_ne3 - n_heads_log2_floor;
int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_arange2_nb[] = {faElemSize};
aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
(char*)tmp_arange_buffer +
n_heads_log2_floor * faElemSize,
faDataType, faElemSize,
tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
n_elements_arange);
int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];
slope_nb[0] = sizeof(float);
for(int i = 1;i<GGML_MAX_DIMS;i++) {
slope_nb[i] = slope_nb[i-1] * slope_ne[0];
}
// init mk_base
ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
ne2_ne3 * faElemSize);
void* tmp_mk_base_buffer = mk_base_allocator.get();
int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
size_t tmp_mk_base1_nb[] = {faElemSize};
aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_base1_ne, tmp_mk_base1_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT, sizeof(float),
slope_ne, slope_nb, GGML_MAX_DIMS);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
size_t tmp_mk_base2_nb[] = {faElemSize};
aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
(char*)tmp_mk_base_buffer +
n_heads_log2_floor * faElemSize,
faDataType, faElemSize,
tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
}
// init mk
int64_t tmp_mk_base_ne[] = {ne2_ne3};
size_t tmp_mk_base_nb[] = {faElemSize};
aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
tmp_arange_buffer, faDataType, faElemSize,
tmp_mk_base_ne, tmp_mk_base_nb,
GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = faElemSize;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
}
aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
tmp_mk_base_buffer, faDataType, faElemSize,
tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
ACL_FORMAT_ND);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
tmp_arange_tensor, tmp_mk_tensor);
ggml_cann_release_resources(ctx, slope_tensor);
}
}
+36
View File
@@ -337,6 +337,29 @@ private:
int32_t device_;
};
#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
struct ggml_cann_graph {
~ggml_cann_graph() {
if (graph != nullptr) {
aclmdlRIDestroy(graph);
}
}
aclmdlRI graph = nullptr;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
};
#endif // USE_ACL_GRAPH
/**
* @brief Context for managing CANN backend operations.
*/
@@ -345,8 +368,13 @@ struct ggml_backend_cann_context {
std::string name; /**< Name of the device. */
std::string description; /**< Description of the device. */
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
#endif
cann_task_queue task_queue;
bool async_mode;
bool support_set_rows;
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
@@ -362,6 +390,14 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
if (!support_set_rows) {
GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
"Falling back to eager mode.\n", __func__);
}
}
/**
+191 -31
View File
@@ -2075,6 +2075,160 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}
#ifdef USE_ACL_GRAPH
/**
* @brief Populate the internal CANN graph node properties from the ggml computation graph.
*
* This function copies all node attributes (operation type, dimensions, strides, input sources,
* and operation parameters) into the cached CANN graph structure for later reuse or comparison.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computational graph.
*/
static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
ggml_tensor * node = cgraph->nodes[node_idx];
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data;
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op;
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
}
for (int src = 0; src < GGML_MAX_SRC; src++) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] =
node->src[src] ? node->src[src]->data : nullptr;
}
memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
}
/**
* @brief Check if a ggml tensor node matches a previously captured CANN graph node.
*
* This function compares all relevant fields (address, op type, shape, source inputs, op params)
* to determine whether the current node matches a previously recorded version.
*
* @param node The current ggml tensor node.
* @param graph_node_properties The stored properties of a CANN graph node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_VIEW) {
return false;
}
if (node->op != graph_node_properties->node_op) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
}
if (node->nb[i] != graph_node_properties->nb[i]) {
return false;
}
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] &&
node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
}
}
if (node->op == GGML_OP_SCALE &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}
return true;
}
/**
* @brief Determine if the CANN graph needs to be rebuilt due to graph changes.
*
* This checks whether the number or properties of ggml graph nodes have changed
* compared to the last captured CANN graph. If so, the CANN graph must be re-captured.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The current ggml computation graph.
* @return true if an update is required; false otherwise.
*/
static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
// The number of nodes is different, so the graph needs to be reconstructed.
if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
return true;
}
// The number of nodes is the same; iterate over each node to check whether they match.
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = ggml_graph_node_has_matching_properties(
cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
if(!has_matching_properties) {
return true;
}
}
return false;
}
#endif // USE_ACL_GRAPH
/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
*
* If CANN graph execution is enabled and graph capture is required, this function begins
* graph capture, runs the graph, ends capture, and stores the captured graph.
*
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computation graph.
* @param use_cann_graph Whether to use CANN graph execution.
* @param cann_graph_update_required Whether graph capture is needed due to graph changes.
*/
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
bool & use_cann_graph, bool & cann_graph_update_required) {
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) {
if (cann_ctx->cann_graph->graph != nullptr) {
ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
cann_ctx->cann_graph->graph = nullptr;
}
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
}
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
if (!use_cann_graph || cann_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
}
}
#ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
}
if (use_cann_graph) {
// Execute graph
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
}
#endif // USE_ACL_GRAPH
}
/**
* @brief Computes a computational graph using a CANN backend.
*
@@ -2091,27 +2245,38 @@ static enum ggml_status ggml_backend_cann_graph_compute(
ggml_backend_t backend, ggml_cgraph* cgraph) {
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
//release temp buffer create by set tensor.
release_nz_workspace();
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor* node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
continue;
}
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
// check environment LLAMA_SET_ROWS
if (!cann_ctx->support_set_rows) {
use_cann_graph = false;
}
if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph());
cann_graph_update_required = true;
}
cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
set_ggml_graph_node_properties(cann_ctx, cgraph);
}
#else
bool use_cann_graph = false;
bool cann_graph_update_required = false;
#endif // USE_ACL_GRAPH
evaluate_and_capture_cann_graph(
cann_ctx,
cgraph,
use_cann_graph,
cann_graph_update_required
);
return GGML_STATUS_SUCCESS;
}
@@ -2226,12 +2391,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// only support F32 and F16.
return false;
}
if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
// unsupport dst is not contiguous.
return false;
}
return true;
} break;
case GGML_OP_CONT: {
@@ -2297,8 +2456,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_SUM:
case GGML_OP_DUP:
case GGML_OP_SUM:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
case GGML_OP_REPEAT:
@@ -2340,9 +2499,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[2]) {
return false;
}
return true;
case GGML_OP_FLASH_ATTN_EXT:{
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2354,6 +2515,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
return false;
}
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[4]) {
return false;
}
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
@@ -2365,11 +2530,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// DeepSeek MLA
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
if (op->src[0]->ne[3] != 1) {
return false;
}
float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
if(logitSoftcap != 0.0f) {
+17
View File
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
#define QI4_1 (QK4_1 / (4 * QR4_1))
#define QR4_1 2
#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
#define QR_MXFP4 2
#define QI5_0 (QK5_0 / (4 * QR5_0))
#define QR5_0 2
@@ -184,6 +187,13 @@ typedef struct {
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK_MXFP4 32
typedef struct {
uint8_t e; // E8M0
uint8_t qs[QK_MXFP4/2];
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
#define QK5_0 32
typedef struct {
ggml_half d; // delta
@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()
// TODO: fix name to kvalues_iq4_nl
GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
// e2m1 values (doubled)
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
GGML_TABLE_END()
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
+6
View File
@@ -13,6 +13,7 @@
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -68,6 +69,7 @@
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -90,6 +92,7 @@
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -120,6 +123,7 @@
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -149,6 +153,7 @@
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -179,6 +184,7 @@
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+61
View File
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
*s = sumf;
}
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_MXFP4 == 0);
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
const block_mxfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP4;
int ib = 0;
float sumf = 0;
#if defined __ARM_NEON
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
uint8x16x2_t q4bits;
int8x16x4_t q4b;
int8x16x4_t q8b;
int32x4_t prod_1;
int32x4_t prod_2;
for (; ib + 1 < nb; ib += 2) {
q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
q8b.val[0] = vld1q_s8(y[ib + 0].qs);
q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
q8b.val[2] = vld1q_s8(y[ib + 1].qs);
q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
sumf +=
GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
}
#endif
for (; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+96 -8
View File
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
}
#if defined(__AVX2__) || defined(__AVX512F__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
const __m256i sy = _mm256_sign_epi8(y, x);
return _mm256_maddubs_epi16(ax, sy);
}
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_MXFP4 == 0);
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
const block_mxfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP4;
int ib = 0;
float sumf = 0;
#if defined __AVX2__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
const __m256i mone = _mm256_set1_epi16(1);
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
_mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
_mm256_cvtepi32_ps(p_1), accum1);
accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
_mm256_cvtepi32_ps(p_2), accum2);
}
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined __AVX__
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
__m256 accum = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
}
sumf = hsum_float_8(accum);
#endif
for (; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}
#if defined(__AVX2__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
const __m256i sy = _mm256_sign_epi8(y, x);
return _mm256_maddubs_epi16(ax, sy);
}
#endif
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
+14 -1
View File
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
},
[GGML_TYPE_MXFP4] = {
.from_float = quantize_row_mxfp4,
.vec_dot = ggml_vec_dot_mxfp4_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q2_K] = {
.from_float = quantize_row_q2_K,
.vec_dot = ggml_vec_dot_q2_K_q8_K,
@@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_add(params, tensor);
} break;
case GGML_OP_ADD_ID:
{
ggml_compute_forward_add_id(params, tensor);
} break;
case GGML_OP_ADD1:
{
ggml_compute_forward_add1(params, tensor);
@@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
ggml_compute_forward_flash_attn_ext(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
@@ -2111,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DUP:
case GGML_OP_CONT:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
{
@@ -2172,6 +2183,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
{
@@ -2673,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan(
}
} break;
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
{
if (ggml_is_quantized(node->src[0]->type)) {
+21 -24
View File
@@ -35,7 +35,7 @@
// ggml-backend interface
std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
static std::vector<ggml_backend_buffer_type_t> bufts = []() {
std::vector<ggml_backend_buffer_type_t> bufts;
@@ -57,8 +57,6 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
#endif
bufts.push_back(NULL);
return bufts;
}();
@@ -66,14 +64,20 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
}
static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
return ggml_backend_cpu_get_extra_buffers_type().data();
static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
bufts.push_back(nullptr);
return bufts;
}();
return extra_bufts.data();
GGML_UNUSED(device);
}
static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra && extra == buft) {
for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra == buft) {
return true;
}
}
@@ -210,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) {
ctx->abort_callback_data = NULL;
ggml_backend_t cpu_backend = new ggml_backend {
/* .guid = */ ggml_backend_cpu_guid(),
/* .interface = */ ggml_backend_cpu_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ ctx,
/* .guid = */ ggml_backend_cpu_guid(),
/* .iface = */ ggml_backend_cpu_i,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
/* .context = */ ctx,
};
if (cpu_backend == NULL) {
@@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
return true;
}
// extra_buffer_op?
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
if (extra) {
auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
if (buf_extra && buf_extra->supports_op(dev, op)) {
return true;
}
}
}
// the other case need host buffer.
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
return false;
// check extra buffer types
// note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
for (int i = 0; i < 4; i++) {
if (op->src[i] && op->src[i]->buffer &&
ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
return buf_extra->supports_op(dev, op);
}
}
+16 -7
View File
@@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t m_start = 0;
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
const int64_t num_threads = KAI_MIN(n / n_step, nth);
int64_t num_threads = KAI_MIN(n / n_step, nth);
if (num_threads <= 0) {
num_threads = 1;
}
if (ith < num_threads) {
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
@@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
GGML_ASSERT(kernel);
const int ith = params->ith;
const int nth = params->nth;
const int nth_raw = params->nth;
const int nth = nth_raw > 0 ? nth_raw : 1;
const size_t k = ne00;
const size_t m = ne11;
@@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
size_t n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
size_t n_to_process = 0;
if (n_start < n) {
n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
}
}
// Calculate number of columns to be processed per thread
@@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
if (n_to_process > 0) {
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
}
return true;
}
+207 -9
View File
@@ -8,6 +8,7 @@
#include "vec.h"
#include <float.h>
#include <algorithm>
// ggml_compute_forward_dup
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
}
}
// ggml_compute_forward_add_id
static void ggml_compute_forward_add_id_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(src0);
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
// src1 indices
const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
GGML_ASSERT(i11 >= 0 && i11 < ne11);
ggml_vec_add_f32(ne0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i11*nb11));
}
}
void ggml_compute_forward_add_id(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_add_id_f32(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
}
}
}
// ggml_compute_forward_add1
static void ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
}
}
// ggml_compute_forward_swiglu_oai
static void ggml_compute_forward_swiglu_oai_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
for (int k = 0; k < nc; k++) {
const float x = std::min(src0_p[k], limit);
const float y = std::clamp(src1_p[k], -limit, limit);
const float out_glu = x / (1.f + expf(alpha * (-x)));
dst_p[k] = out_glu * (y + 1.f);
}
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = dst_p[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_swiglu_oai(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_swiglu_oai_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_geglu_erf
static void ggml_compute_forward_geglu_erf_f32(
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
// sinks
const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
float max = -INFINITY;
ggml_vec_max_f32(ne00, &max, wp);
// if we have sinks, make a correction as if they were included in the softmax
if (sk) {
max = MAX(max, sk[i02]);
}
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
assert(sum > 0.0);
if (sk) {
sum += (ggml_float) expf(sk[i02] - max);
}
sum = 1.0/sum;
ggml_vec_scale_f32(ne00, dp, sum);
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort(
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
const ggml_tensor * q,
const ggml_tensor * k,
const ggml_tensor * v,
const ggml_tensor * mask,
ggml_tensor * dst) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
}
// sinks
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
ggml_vec_scale_f32(DV, VKQ32, ms);
} else {
vs = expf(s - M);
}
S = S*ms + vs;
}
// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
@@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
void ggml_compute_forward_flash_attn_ext(
const ggml_compute_params * params,
const ggml_tensor * q,
const ggml_tensor * k,
const ggml_tensor * v,
const ggml_tensor * mask,
ggml_tensor * dst) {
switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
// uses F32 accumulators
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
ggml_compute_forward_flash_attn_ext_f16(params, dst);
} break;
default:
{
@@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu(
{
ggml_compute_forward_swiglu(params, dst);
} break;
case GGML_GLU_OP_SWIGLU_OAI:
{
ggml_compute_forward_swiglu_oai(params, dst);
} break;
case GGML_GLU_OP_GEGLU_ERF:
{
ggml_compute_forward_geglu_erf(params, dst);
+2 -7
View File
@@ -29,6 +29,7 @@ extern "C" {
void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(
const struct ggml_compute_params * params,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
const bool masked,
+35
View File
@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
quantize_row_q8_1_ref(x, y, k);
}
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
quantize_row_mxfp4_ref(x, y, k);
}
//
// 2-6 bit quantization in super-blocks
//
@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
*s = sumf;
}
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_MXFP4 == 0);
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
const block_mxfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_MXFP4;
int ib = 0;
float sumf = 0;
for (; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
int sumi1 = 0;
int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) {
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+8
View File
@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+2 -2
View File
@@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {}
} // namespace ggml::cpu
bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
@@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct
}
bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
if (extra && extra->context) {
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
auto tensor_traits = buf_extra->get_tensor_traits(op);
+1 -1
View File
@@ -33,6 +33,6 @@ class extra_buffer_type {
} // namespace ggml::cpu
// implemented in ggml-cpu.cpp.
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types();
#endif
+19 -4
View File
@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
int i = 0;
#if defined(__AVX2__)
for (; i + 7 < n; i += 8) {
__m256 vx = _mm256_loadu_ps(x + i);
__m256 vy = _mm256_loadu_ps(y + i);
__m256 vz = _mm256_add_ps(vx, vy);
_mm256_storeu_ps(z + i, vz);
}
#endif
for (; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
for (int i = 0; i < n; ++i) {
z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
@@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
float v = GGML_CPU_FP16_TO_FP32(x[i]);
float w = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
float xi = GGML_CPU_FP16_TO_FP32(x[i]);
float gi = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
}
}
+4
View File
@@ -120,6 +120,10 @@ if (CUDAToolkit_FOUND)
set(CUDA_FLAGS -use_fast_math -extended-lambda)
if (GGML_CUDA_DEBUG)
list(APPEND CUDA_FLAGS -lineinfo)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
# Options are:
# - none (not recommended)
+58
View File
@@ -0,0 +1,58 @@
#include "add-id.cuh"
static __global__ void add_id_kernel(
const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne0, int64_t ne1,
size_t nb01, size_t nb02,
size_t nb11,
size_t nb21
) {
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
const float * src1_row = (const float *)((char *)src1 + i11*nb11);
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(int32_t));
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const int32_t * src2_d = (const int32_t *)src2->data;
float * dst_d = (float *)dst->data;
int threads = std::min((int)ne00, 768); // cols
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src0_d, src1_d, src2_d, dst_d,
ne0, ne1,
nb01, nb02,
nb11,
nb21
);
}
+3
View File
@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+37 -3
View File
@@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include "ggml-impl.h"
#include "ggml-cuda.h"
#include <cstdint>
@@ -232,9 +233,13 @@ typedef float2 dfloat2;
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -302,12 +307,16 @@ static bool amd_mfma_available(const int cc) {
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
static bool ampere_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
@@ -549,6 +558,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#endif // defined(GGML_USE_HIP)
}
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
return (float) e;
#else
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
float result;
memcpy(&result, &bits, sizeof(float));
return result;
#endif // CUDART_VERSION >= 12050
}
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __device__ __forceinline__ float get_alibi_slope(
@@ -607,6 +634,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
static constexpr int qi = QI8_0;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
static constexpr int qk = QK_MXFP4;
static constexpr int qr = QR_MXFP4;
static constexpr int qi = QI_MXFP4;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;
+28
View File
@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
}
}
template<typename dst_t>
static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
const int64_t tid = threadIdx.x;
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
}
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
+4 -1
View File
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -736,7 +737,8 @@ void launch_fattn(
GGML_ASSERT(V || is_mla);
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -940,6 +942,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *) sinks->data) : nullptr,
KV_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
+79 -24
View File
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float * const __restrict__ KQ_max,
float * const __restrict__ KQ_rowsum,
const int kb0) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
typedef fattn_mma_f16_config<DKQ, DV> c;
#ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
@@ -800,7 +801,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
// so it's being done unconditionally for every thread.
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new;
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
const float KQ_max_add = expf(sink - KQ_max_new);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}
if (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
} else {
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
}
}
}
}
}
// Combine VKQ accumulator values if np > 1.
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1196,7 +1243,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@@ -1206,6 +1253,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -1222,7 +1270,7 @@ static __global__ void flash_attn_ext_f16(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1267,20 +1315,24 @@ static __global__ void flash_attn_ext_f16(
// kb0 == k start index when in the output tile.
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const int head0 = zt * ncols2;
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1293,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1314,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
}
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
const int head0 = zt * ncols2;
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1337,10 +1392,10 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
@@ -1352,7 +1407,7 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
}
template <int DKQ, int DV, int ncols1, int ncols2>
+32 -5
View File
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -48,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
@@ -241,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
kqmax[j0/nwarps] = kqmax_new_j;
const half val = hexp(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
@@ -272,7 +299,7 @@ static __global__ void flash_attn_tile_ext_f16(
}
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+34 -5
View File
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32(
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
@@ -59,10 +60,11 @@ static __global__ void flash_attn_tile_ext_f32(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
@@ -251,6 +253,33 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
+39 -3
View File
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -61,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
@@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
half2 * KQ2 = (half2 *) KQ;
half kqmax[ncols];
half kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -HALF_MAX_HALF;
kqsum[j] = 0.0f;
}
half kqsum[ncols] = {0.0f};
__shared__ half kqmax_shared[ncols][WARP_SIZE];
__shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
}
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const half val = hexp(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= __half2half2(KQ_max_scale);
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
@@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16(
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+38 -2
View File
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -72,7 +73,8 @@ static __global__ void flash_attn_vec_ext_f32(
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
@@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
}
float kqmax[ncols];
float kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -FLT_MAX/2.0f;
kqsum[j] = 0.0f;
}
float kqsum[ncols] = {0.0f};
__shared__ float kqmax_shared[ncols][WARP_SIZE];
__shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
}
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const float val = expf(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= KQ_max_scale;
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);
+55 -6
View File
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
@@ -81,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const half2 * mask2 = (const half2 *) maskh;
const float * sinksf = (const float *) sinks;
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
@@ -380,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
}
// Apply attention sinks
if (sinksf && blockIdx.y == 0) {
const float sinkf = sinksf[head];
const half sinkh = __float2half(sinkf);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
KQ_max_f[j0/nwarps] = kqmax_new;
KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= scale_h2;
}
} else {
half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
half kqmax_new = fmaxf(kqmax_old, sinkh);
KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
const half val = hexp(sinkh - kqmax_new);
KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + warp_size > D/2 && i >= D/2) break;
VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
}
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
@@ -423,7 +472,7 @@ static __global__ void flash_attn_ext_f16(
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+7 -7
View File
@@ -269,11 +269,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -316,7 +316,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
@@ -329,7 +329,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (fp16_mma_available(cc) && !new_mma_available(cc)) {
if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
+49 -16
View File
@@ -4,6 +4,7 @@
#include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh"
#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/argsort.cuh"
@@ -21,8 +22,9 @@
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmf.cuh"
#include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmv.cuh"
#include "ggml-cuda/mmvf.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh"
@@ -2007,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_f = !ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2027,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
@@ -2047,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
//TODO update for generic tensor parallelism
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
if (!split && use_mul_mat_vec) {
if (!split && use_mul_mat_vec_f) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_f) {
ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_vec_q) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_q) {
@@ -2064,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// general KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_mul_mat_vec) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_f) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
} else if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {
@@ -2093,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
} else {
ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
}
return;
}
@@ -2259,6 +2269,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ADD1: // TODO: more efficient implementation
ggml_cuda_op_add(ctx, dst);
break;
case GGML_OP_ADD_ID:
ggml_cuda_op_add_id(ctx, dst);
break;
case GGML_OP_SUB:
ggml_cuda_op_sub(ctx, dst);
break;
@@ -2333,6 +2346,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
case GGML_GLU_OP_SWIGLU_OAI:
ggml_cuda_op_swiglu_oai(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_cuda_op_geglu_erf(ctx, dst);
break;
@@ -2607,6 +2623,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2629,7 +2648,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
@@ -3227,6 +3252,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
@@ -3277,6 +3303,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -3423,6 +3450,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -3497,12 +3525,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!new_mma_available(cc)) {
if (!turing_mma_available(cc)) {
return false;
}
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
}
// TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
&& op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
@@ -3766,10 +3799,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
}
ggml_backend_t cuda_backend = new ggml_backend {
/* .guid = */ ggml_backend_cuda_guid(),
/* .interface = */ ggml_backend_cuda_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .context = */ ctx,
/* .guid = */ ggml_backend_cuda_guid(),
/* .iface = */ ggml_backend_cuda_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .context = */ ctx,
};
return cuda_backend;
+3 -2
View File
@@ -1,7 +1,5 @@
#include "im2col.cuh"
#define MIN(a, b) (a) < (b) ? (a) : (b)
#define MAX_GRIDDIM_Z 65535
template <typename T>
@@ -38,6 +36,9 @@ static __global__ void im2col_kernel(
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
GGML_UNUSED(IC);
GGML_UNUSED(KH);
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+88 -22
View File
@@ -23,13 +23,13 @@
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(ret) : "r"(x));
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
#endif // defined(TURING_MMA_AVAILABLE)
return ret;
}
@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
}
};
template <int I_, int J_>
struct tile<I_, J_, nv_bfloat162> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 8 && J == 8) {
return threadIdx.x / 4;
} else if constexpr (I == 16 && J == 4) {
return l * 8 + threadIdx.x / 4;
} else if constexpr (I == 16 && J == 8) {
return (l % 2) * 8 + threadIdx.x / 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 8 && J == 8) {
return l * 4 + threadIdx.x % 4;
} else if constexpr (I == 16 && J == 4) {
return threadIdx.x % 4;
} else if constexpr (I == 16 && J == 8) {
return (l / 2) * 4 + threadIdx.x % 4;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
};
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
tile<I, J/2, half2> ret;
@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int *) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
#else
load_generic(xs0, stride);
GGML_UNUSED(t);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(NEW_MMA_AVAILABLE)
#if defined(TURING_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -246,13 +278,13 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(t, xs0, stride);
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
@@ -263,12 +295,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -317,12 +349,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -344,12 +376,12 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -380,12 +412,29 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -407,12 +456,29 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
int * Dxi = (int *) D.x;
@@ -443,7 +509,7 @@ namespace ggml_cuda_mma {
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}
static __device__ __forceinline__ void mma(
+431
View File
@@ -0,0 +1,431 @@
#include "ggml.h"
#include "common.cuh"
#include "mma.cuh"
#include "mmf.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int channel_dst = blockIdx.y;
const int channel_x = channel_dst / channel_ratio;
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
tile_C C[ntA][ntB];
T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) data_mmv;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
template <typename T, int cols_per_block>
static void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
GGML_ASSERT(!ids && "mul_mat_id not implemented");
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 2: {
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 3: {
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 4: {
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 5: {
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 6: {
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 7: {
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 8: {
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
}
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
return false;
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
if (ne11 > 16) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
case GGML_TYPE_F16:
return turing_mma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc);
default:
return false;
}
}
+5
View File
@@ -0,0 +1,5 @@
#include "common.cuh"
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
+5 -1
View File
@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
@@ -282,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -306,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
if (new_mma_available(cc)) {
if (turing_mma_available(cc)) {
return true;
}
+206 -128
View File
File diff suppressed because it is too large Load Diff
@@ -1,9 +1,9 @@
#include "ggml.h"
#include "common.cuh"
#include "mmv.cuh"
#include "mmvf.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec(
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
@@ -37,7 +37,7 @@ static __global__ void mul_mat_vec(
float sumf[ncols_dst] = {0.0f};
if constexpr (std::is_same<T, float>::value) {
if constexpr (std::is_same_v<T, float>) {
const float2 * x2 = (const float2 *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
@@ -50,10 +50,10 @@ static __global__ void mul_mat_vec(
sumf[j] += tmpx.y*tmpy.y;
}
}
} else if constexpr (std::is_same<T, half>::value) {
} else if constexpr (std::is_same_v<T, half>) {
const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) {
if (std::is_same_v<type_acc, float>) {
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
@@ -86,7 +86,7 @@ static __global__ void mul_mat_vec(
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
const int * x2 = (const int *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
@@ -98,7 +98,7 @@ static __global__ void mul_mat_vec(
}
}
} else {
static_assert(std::is_same<T, void>::value, "unsupported type");
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
@@ -126,7 +126,7 @@ static __global__ void mul_mat_vec(
}
template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_cuda(
static void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda(
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
int device;
int warp_size;
CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
@@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda(
}
}
const int smem = warp_size*sizeof(float);
const int nbytes_shared = warp_size*sizeof(float);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
@@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda(
}
template <typename T, typename type_acc>
static void mul_mat_vec_cuda_switch_ncols_dst(
static void mul_mat_vec_f_cuda_switch_ncols_dst(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
@@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
cudaStream_t stream) {
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_cuda<T, type_acc, 1>
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 2:
launch_mul_mat_vec_cuda<T, type_acc, 2>
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 3:
launch_mul_mat_vec_cuda<T, type_acc, 3>
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 4:
launch_mul_mat_vec_cuda<T, type_acc, 4>
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 5:
launch_mul_mat_vec_cuda<T, type_acc, 5>
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 6:
launch_mul_mat_vec_cuda<T, type_acc, 6>
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 7:
launch_mul_mat_vec_cuda<T, type_acc, 7>
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 8:
launch_mul_mat_vec_cuda<T, type_acc, 8>
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
@@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
}
template<typename T>
static void mul_mat_vec_cuda(
static void mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
@@ -292,22 +290,22 @@ static void mul_mat_vec_cuda(
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same<T, half>::value) {
if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_cuda_switch_ncols_dst<T, half>
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
mul_mat_vec_cuda_switch_ncols_dst<T, float>
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -355,19 +353,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
@@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
}
}
void ggml_cuda_op_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
@@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec(
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
@@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec(
GGML_UNUSED(src1_padded_row_size);
}
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % 2 != 0) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return ne11 <= 8;
if (ampere_mma_available(cc)) {
return ne11 <= 3;
}
if (cc >= GGML_CUDA_CC_TURING) {
return ne11 <= 4;
@@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
case GGML_TYPE_F16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
@@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
case GGML_TYPE_BF16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (ampere_mma_available(cc)) {
return src0_small && ne11 == 1;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
@@ -1,11 +1,11 @@
#include "common.cuh"
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
+9
View File
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+16 -10
View File
@@ -45,7 +45,7 @@ struct soft_max_params {
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const soft_max_params p) {
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int tid = threadIdx.x;
@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
// shared memory buffer to cache values between iterations:
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
float max_val = -INFINITY;
float max_val = sinks ? sinks[i02] : -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
tmp = warp_reduce_sum(tmp);
}
if (sinks) {
tmp += expf(sinks[i02] - max_val);
}
const float inv_sum = 1.0f / tmp;
#pragma unroll
@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
}
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, p);
(x, mask, sinks, dst, p);
return true;
}
return false;
@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
params.m1 = m1;
if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
}
}
+153 -71
View File
@@ -1,87 +1,117 @@
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#ifdef USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // USE_CUB
#include "ssm-scan.cuh"
template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t d_inner, const int64_t L) {
const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
const size_t L = L_template == 0 ? L_param : L_template;
const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int bidx = blockIdx.x; // split along B (sequences)
const int bidy = blockIdx.y; // split along D (d_inner)
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;
extern __shared__ float smem[];
const int stride_sA = N + 1;
const int stride_ss0 = N + 1;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_x = src1_nb2 / sizeof(float);
const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_s = stride_s0;
const int stride_y = d_inner;
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = d_inner;
// can N not be 16? for example 32?
if (N == 16) {
float regA[N];
float regs0[N];
__shared__ float smemB[N];
__shared__ float smemC[N];
#ifdef USE_CUB
using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
union CubTempStorage {
typename BlockLoad::TempStorage load_temp;
typename BlockStore::TempStorage store_temp;
};
__shared__ CubTempStorage cub_temp_storage;
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
for (size_t n = 0; n < N; ++n)
{
regA[n] = A_block[threadIdx.x * stride_A + n];
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
}
#endif
__syncthreads();
for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
if (i == L - 1) {
s_block[tid * stride_s + j] = state;
} else {
smem_s0[tid * stride_ss0 + j] = state;
}
for (size_t i = 0; i < L; i++)
{
if (threadIdx.x < N)
{
smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
}
__syncthreads();
y_block[i * stride_y + tid] = sumf;
float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
if (dt_soft_plus <= 20.0f)
{
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (size_t n = 0; n < N; n++)
{
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
sumf += state * smemC[n];
regs0[n] = state;
}
y_block[i * stride_y + threadIdx.x] = sumf;
}
#ifdef USE_CUB
BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
const int stride_s = stride_s0;
#pragma unroll
for (size_t n = 0; n < N; ++n)
{
s_block[threadIdx.x * stride_s + n] = regs0[n];
}
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
// assumes as many threads as d_state
template <int splitH, int d_state>
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
const int threads = 128;
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
@@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
GGML_ABORT("doesn't support d_state!=(128 or 256).");
}
} else {
const int threads = 128;
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
@@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
switch (n_tok)
{
case 1:
ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 2:
ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 3:
ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 4:
ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 5:
ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 6:
ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 7:
ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 8:
ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
default:
ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
}
} else {
GGML_ABORT("doesn't support d_state!=16.");
}
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_MXFP4);
+75
View File
@@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
// swiglu_oai
template <typename T>
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
// perform base op and multiply with gate (either offset in same tensor or a separate one)
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
float xi = x[j0];
float gi = g[j1];
xi = fminf(xi, limit);
gi = fmaxf(fminf(gi, limit), -limit);
float out_glu = xi / (1.0f + expf(-xi * alpha));
out_glu = out_glu * (1.0f + gi);
dst[i] = out_glu;
}
template <typename T>
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
}
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
void * src0_d = src0->data;
void * src1_d = src1 ? src1->data : src0->data;
const int64_t src0_o = src0->nb[1];
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
void * dst_d = dst->data;
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
GGML_ASSERT(src1->ne[0] == nc);
GGML_ASSERT(src0->type == src1->type);
}
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
float * src0_p = (float *) src0_d;
float * src1_p = (float *) src1_d;
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
+2
View File
@@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+52 -16
View File
@@ -1,8 +1,20 @@
#pragma once
#include "common.cuh"
#include <cstdint>
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
const uint8_t * x8 = (const uint8_t *) x;
int x32 = x8[4*i32 + 0] << 0;
x32 |= x8[4*i32 + 1] << 8;
x32 |= x8[4*i32 + 2] << 16;
x32 |= x8[4*i32 + 3] << 24;
return x32;
}
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
return d8_1*sumf;
}
#define VDR_MXFP4_Q8_1_MMVQ 2
#define VDR_MXFP4_Q8_1_MMQ 4
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
const int * q8 = (const int *) bq8_1->qs + iqs;
int sumi = 0;
#pragma unroll
for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
}
const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
return d * sumi;
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 4
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ 4
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4);
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
const int2 v = get_int_from_table_16(aux_q4);
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
+4
View File
@@ -6,6 +6,10 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDART_VERSION >= 12050
#include <cuda_fp8.h>
#endif // CUDART_VERSION >= 12050
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
+1
View File
@@ -200,6 +200,7 @@
#endif
typedef hip_bfloat16 nv_bfloat16;
typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+2 -1
View File
@@ -137,4 +137,5 @@
#define cudaStreamEndCapture musaStreamEndCapture
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
typedef mt_bfloat16 nv_bfloat16;
typedef __mt_bfloat16 nv_bfloat16;
typedef __mt_bfloat162 nv_bfloat162;
+4
View File
@@ -121,6 +121,10 @@ if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
endif()
if (GGML_HIP_EXPORT_METRICS)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
endif()
if (NOT GGML_CUDA_FA)
add_compile_definitions(GGML_CUDA_NO_FA)
endif()
+61
View File
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
static inline float ggml_e8m0_to_fp32(uint8_t x) {
uint32_t bits; // Stores the raw bit representation of the float
// Handle special case for minimum exponent (denormalized float)
if (x == 0) {
// Bit pattern for 2^(-127):
// - Sign bit: 0 (positive)
// - Exponent: 0 (denormalized number)
// - Mantissa: 0x400000 (0.5 in fractional form)
// Value = 0.5 * 2^(-126) = 2^(-127)
bits = 0x00400000;
}
// note: disabled as we don't need to handle NaNs
//// Handle special case for NaN (all bits set)
//else if (x == 0xFF) {
// // Standard quiet NaN pattern:
// // - Sign bit: 0
// // - Exponent: all 1s (0xFF)
// // - Mantissa: 0x400000 (quiet NaN flag)
// bits = 0x7FC00000;
//}
// Normalized values (most common case)
else {
// Construct normalized float by shifting exponent into position:
// - Exponent field: 8 bits (positions 30-23)
// - Mantissa: 0 (implicit leading 1)
// Value = 2^(x - 127)
bits = (uint32_t) x << 23;
}
float result; // Final float value
// Safely reinterpret bit pattern as float without type-punning issues
memcpy(&result, &bits, sizeof(float));
return result;
}
// Equal to ggml_e8m0_to_fp32/2
// Useful with MXFP4 quantization since the E0M2 values are doubled
static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
uint32_t bits;
// For x < 2: use precomputed denormal patterns
if (x < 2) {
// 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
bits = 0x00200000 << x;
}
// For x >= 2: normalized exponent adjustment
else {
// 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
bits = (uint32_t)(x - 1) << 23;
}
// Note: NaNs are not handled here
float result;
memcpy(&result, &bits, sizeof(float));
return result;
}
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
/**
* Converts brain16 to float32.
*
+14
View File
@@ -23,6 +23,9 @@
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_R0_Q2_K 4
#define N_SG_Q2_K 2
@@ -129,6 +132,15 @@ typedef struct {
uint64_t o1[8];
} ggml_metal_kargs_bin;
typedef struct {
int64_t ne0;
int64_t ne1;
size_t nb01;
size_t nb02;
size_t nb11;
size_t nb21;
} ggml_metal_kargs_add_id;
typedef struct {
int32_t ne00;
int32_t ne01;
@@ -444,6 +456,8 @@ typedef struct{
uint64_t nb1;
int32_t i00;
int32_t i10;
float alpha;
float limit;
} ggml_metal_kargs_glu;
typedef struct {
+109 -9
View File
@@ -195,6 +195,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
GGML_METAL_KERNEL_TYPE_ADD_ID,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -234,6 +235,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
@@ -286,6 +288,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -310,6 +313,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
@@ -351,6 +358,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
@@ -373,6 +381,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
@@ -397,6 +406,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
@@ -579,6 +589,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -1199,6 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1238,6 +1250,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
@@ -1290,6 +1303,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1314,6 +1328,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
@@ -1355,6 +1373,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
@@ -1377,6 +1396,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
@@ -1401,6 +1422,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
@@ -1583,6 +1605,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -1774,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -1791,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
@@ -2042,6 +2067,7 @@ static int ggml_metal_encode_node(
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
size_t offs_src0 = 0;
@@ -2291,6 +2317,38 @@ static int ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
case GGML_OP_ADD_ID:
{
GGML_ASSERT(src0t == GGML_TYPE_F32);
GGML_ASSERT(src1t == GGML_TYPE_F32);
GGML_ASSERT(src2t == GGML_TYPE_I32);
GGML_ASSERT(dstt == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous_rows(src0));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
ggml_metal_kargs_add_id args = {
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb11 =*/ nb11,
/*.nb21 =*/ nb21,
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_REPEAT:
{
id<MTLComputePipelineState> pipeline;
@@ -2710,6 +2768,9 @@ static int ggml_metal_encode_node(
case GGML_GLU_OP_SWIGLU:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
break;
case GGML_GLU_OP_SWIGLU_OAI:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
break;
case GGML_GLU_OP_GEGLU_ERF:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
break;
@@ -2720,7 +2781,9 @@ static int ggml_metal_encode_node(
GGML_ABORT("fatal error");
}
const int32_t swp = ((const int32_t *) dst->op_params)[1];
const int32_t swp = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
const int32_t i00 = swp ? ne0 : 0;
const int32_t i10 = swp ? 0 : ne0;
@@ -2734,6 +2797,8 @@ static int ggml_metal_encode_node(
/*.nb1 =*/ nb1,
/*.i00 =*/ src1 ? 0 : i00,
/*.i10 =*/ src1 ? 0 : i10,
/*.alpha=*/ alpha,
/*.limit=*/ limit
};
[encoder setComputePipelineState:pipeline];
@@ -2992,8 +3057,13 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
if (id_src2) {
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
} else {
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&args length:sizeof(args) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
@@ -3291,6 +3361,7 @@ static int ggml_metal_encode_node(
src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 ||
src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_MXFP4 ||
src0t == GGML_TYPE_IQ4_NL ||
false) && (ne11 >= 2 && ne11 <= 8)
) ||
@@ -3383,6 +3454,14 @@ static int ggml_metal_encode_node(
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_MXFP4:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_Q4_K:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
@@ -3481,6 +3560,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@@ -3623,6 +3703,13 @@ static int ggml_metal_encode_node(
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
{
nsg = N_SG_MXFP4;
nr0 = N_R0_MXFP4;
smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
@@ -3756,8 +3843,6 @@ static int ggml_metal_encode_node(
case GGML_OP_MUL_MAT_ID:
{
// src2 = ids
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
GGML_ASSERT(src2t == GGML_TYPE_I32);
GGML_ASSERT(!ggml_is_transposed(src0));
@@ -3883,6 +3968,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
@@ -4018,6 +4104,13 @@ static int ggml_metal_encode_node(
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
{
nsg = N_SG_MXFP4;
nr0 = N_R0_MXFP4;
smem = 32*sizeof(float);
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
} break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
@@ -4170,6 +4263,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
@@ -4980,11 +5074,14 @@ static int ggml_metal_encode_node(
GGML_ASSERT(ne11 == ne21);
GGML_ASSERT(ne12 == ne22);
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src3 = node->src[3]; // mask
struct ggml_tensor * src4 = node->src[4]; // sinks
size_t offs_src3 = 0;
size_t offs_src4 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
@@ -5000,8 +5097,6 @@ static int ggml_metal_encode_node(
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
float scale;
float max_bias;
float logit_softcap;
@@ -5389,7 +5484,12 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
if (id_src4) {
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
+272 -15
View File
@@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
@@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) {
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static inline float e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
return as_type<float>(bits);
}
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
dst.qs[j] = round(x0);
}
}
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
@@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
const float d = e8m0_to_fp32(xb->e);
const uint8_t shr = il >= 1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
dst.d = d;
const float d = e8m0_to_fp32(xb->e);
const short il4 = il%4;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
const uint8_t shr = il >= 4 ? 4 : 0;
dst.qs[j] = round(x0);
}
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
@@ -960,6 +1006,32 @@ kernel void kernel_div(
}
}
kernel void kernel_add_id(
constant ggml_metal_kargs_add_id & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i1 = tgpig.x;
const int i2 = tgpig.y;
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
const size_t nb1 = args.ne0 * sizeof(float);
const size_t nb2 = args.ne1 * nb1;
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
@@ -1431,6 +1503,32 @@ kernel void kernel_swiglu(
}
}
kernel void kernel_swiglu_oai(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, args.limit);
x1 = max(min(x1, args.limit), -args.limit);
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = out_glu;
}
}
kernel void kernel_geglu_erf(
device const char * src0,
device const char * src1,
@@ -1534,6 +1632,7 @@ template<typename T>
kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
@@ -1552,6 +1651,7 @@ kernel void kernel_soft_max(
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
@@ -1567,7 +1667,7 @@ kernel void kernel_soft_max(
}
// parallel max
float lmax = -INFINITY;
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
@@ -1623,6 +1723,10 @@ kernel void kernel_soft_max(
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
@@ -1634,6 +1738,7 @@ template<typename T>
kernel void kernel_soft_max_4(
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
@@ -1652,6 +1757,7 @@ kernel void kernel_soft_max_4(
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
@@ -1666,7 +1772,7 @@ kernel void kernel_soft_max_4(
}
// parallel max
float4 lmax4 = -INFINITY;
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
@@ -1725,6 +1831,10 @@ kernel void kernel_soft_max_4(
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
@@ -3106,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@@ -4092,6 +4207,7 @@ kernel void kernel_flash_attn_ext(
device const char * k,
device const char * v,
device const char * mask,
device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4407,6 +4523,35 @@ kernel void kernel_flash_attn_ext(
}
}
if (sinks != q && sgitg == 0) {
for (ushort j = 0; j < Q; ++j) {
const float m = M[j];
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
M[j] = simd_max(max(M[j], s));
const float ms = exp(m - M[j]);
const float vs = exp(s - M[j]);
S[j] = S[j]*ms + simd_sum(vs);
if (tiisg == j) {
ss[j*TS + 2*C + j] = ms;
}
}
// O = diag(ms)*O
{
s8x8_t ms;
simdgroup_load(ms, ss + 2*C, TS, 0, false);
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
simdgroup_multiply(lo[i], ms, lo[i]);
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = tiisg; j < Q; j += NW) {
ss[j*TS + 0] = S[j];
@@ -4618,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec(
device const char * k,
device const char * v,
device const char * mask,
device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -4835,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec(
}
}
if (sinks != q && sgitg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
M = simd_max(max(M, s));
const float ms = exp(m - M);
const float vs = exp(s - M);
S = S*ms + simd_sum(vs);
#pragma unroll(DV4/NL)
for (short ii = 0; ii < DV4; ii += NL) {
lo[ii/NL] *= ms;
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
ss[0] = (s_t) S;
@@ -6940,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32(
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int nr0, int nsg, int nw, typename args_t>
void kernel_mul_mv_mxfp4_f32_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int nb = args.ne00/QK_MXFP4;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f};
device const float * yb = y + ix * QK_MXFP4 + it * 8;
for (int ib = ix; ib < nb; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
#pragma unroll(nr0)
for (short row = 0; row < nr0; row++) {
device const block_mxfp4 & xb = x[row*nb + ib];
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
acc1 = (acc1 + acc3) + (acc2 + acc4);
sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
}
yb += 16 * QK_MXFP4;
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
}
}
}
[[host_name("kernel_mul_mv_mxfp4_f32")]]
kernel void kernel_mul_mv_mxfp4_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
@@ -7475,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7527,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7558,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
@@ -7703,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
+1
View File
@@ -55,6 +55,7 @@ endfunction()
set(GGML_OPENCL_KERNELS
add
add_id
argsort
clamp
cpy
+263 -96
View File
@@ -345,6 +345,7 @@ struct ggml_backend_opencl_context {
cl_command_queue queue;
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
cl_program program_cpy;
cl_program program_cvt;
@@ -404,6 +405,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
cl_kernel kernel_add_id;
cl_kernel kernel_scale;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
@@ -412,7 +414,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_relu;
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
@@ -600,6 +602,7 @@ struct ggml_backend_opencl_context {
if (ref_count == 0) {
#ifdef GGML_OPENCL_PROFILING
write_profiling_info();
profiling_info.clear();
#endif
}
}
@@ -681,6 +684,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// add_id
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "add_id.cl.h"
};
#else
const std::string kernel_src = read_file("add_id.cl");
#endif
backend_ctx->program_add_id =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err));
GGML_LOG_CONT(".");
}
// clamp
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -787,6 +806,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu_oai = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
@@ -2461,12 +2481,21 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_OP_SCALE:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
if (op->type == GGML_TYPE_F16) {
const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
if (src0_ok && src1_ok) {
return true;
}
}
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SUB:
return (op->src[0]->type == op->src[1]->type) &&
(op->src[0]->type == op->type) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
@@ -2488,6 +2517,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
@@ -2601,10 +2631,10 @@ ggml_backend_t ggml_backend_opencl_init(void) {
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_opencl_guid(),
/* .interface = */ ggml_backend_opencl_i,
/* .device = */ dev,
/* .context = */ backend_ctx
/* .guid = */ ggml_backend_opencl_guid(),
/* .iface = */ ggml_backend_opencl_i,
/* .device = */ dev,
/* .context = */ backend_ctx
};
return backend;
@@ -3694,34 +3724,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3]; UNUSED(ne13);
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3];
const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
const cl_ulong nb13 = src1->nb[3];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
@@ -3738,68 +3764,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
bool bcast_row = false;
cl_kernel kernel;
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;
if (bcast_row) {
GGML_ASSERT(ggml_is_contiguous(src0));
// src1 is a row
GGML_ASSERT(ne11 == 1);
}
bcast_row = true;
int ne = ne00 / 4;
if (src0->type == GGML_TYPE_F32) {
if (dst->type == GGML_TYPE_F32) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
if (bcast_row) {
kernel = backend_ctx->kernel_add_row;
const int ne = ne00 / 4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
} else {
kernel = backend_ctx->kernel_add_row_f16;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_add;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
}
} else if (dst->type == GGML_TYPE_F16) {
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
const int type_src0 = (src0->type == GGML_TYPE_F32);
const int type_src1 = (src1->type == GGML_TYPE_F32);
if (bcast_row) {
kernel = backend_ctx->kernel_add_row_f16;
const int ne = ne00 / 4;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1));
} else {
kernel = backend_ctx->kernel_add_f16;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0));
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1));
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
} else {
GGML_ASSERT(false && "unsupported data types for add");
}
if (bcast_row) {
@@ -3809,19 +3881,88 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr; // Let driver choose the work-group sizes.
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
} else {
unsigned int nth = MIN(64, ne0);
size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
}
static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(src2);
GGML_ASSERT(src2->extra);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous_rows(src0));
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb21 = src2->nb[1];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel = backend_ctx->kernel_add_id;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
size_t local_work_size[] = { (size_t)nth, 1, 1 };
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -6500,17 +6641,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ASSERT(src1->extra);
}
const ggml_tensor * src2 = dst->src[2];
if (src2) {
GGML_ASSERT(src2->extra);
}
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
@@ -6578,25 +6726,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
@@ -7003,6 +7153,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_swiglu_f16;
}
break;
case GGML_GLU_OP_SWIGLU_OAI:
kernel = backend_ctx->kernel_swiglu_oai;
break;
case GGML_GLU_OP_GEGLU_ERF:
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_geglu_erf;
@@ -7038,7 +7191,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
const cl_ulong nb1 = dst->nb[1];
const int swp = ((const int32_t *) dst->op_params)[1];
const int swp = ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
@@ -7055,6 +7211,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off));
if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));
}
const size_t nrows = ggml_nrows(src0);
size_t nth = 512;
size_t global_work_size[] = {nrows*nth, 1, 1};
@@ -7111,6 +7272,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_add;
break;
case GGML_OP_ADD_ID:
if (!any_on_device) {
return false;
}
func = ggml_cl_add_id;
break;
case GGML_OP_MUL:
if (!any_on_device) {
return false;
+42 -8
View File
@@ -112,7 +112,9 @@ kernel void kernel_add_f16(
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
ulong nb3,
int type_src0,
int type_src1
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
@@ -132,25 +134,57 @@ kernel void kernel_add_f16(
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const int i10 = i0 % ne10;
*((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) + *((global half *)(src1_ptr + i10*nb10));
half v0, v1;
if (type_src0 == 1) {
v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
} else {
v0 = *((global half *)(src0_ptr + i0*nb00));
}
if (type_src1 == 1) {
v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
} else {
v1 = *((global half *)(src1_ptr + i10*nb10));
}
*((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
}
}
kernel void kernel_add_row_f16(
global half4 * src0,
global char * src0,
ulong offset0,
global half4 * src1,
global char * src1,
ulong offset1,
global half4 * dst,
ulong offsetd,
int ne
int ne,
int type_src0,
int type_src1
) {
src0 = (global half4*)((global char*)src0 + offset0);
src1 = (global half4*)((global char*)src1 + offset1);
dst = (global half4*)((global char*)dst + offsetd);
// This performs better than using %.
uint gid = get_global_id(0);
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
dst[gid] = src0[gid] + src1[idx1];
half4 v0, v1;
if (type_src0 == 1) {
global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
v0 = convert_half4(src0_f32[gid]);
} else {
global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
v0 = src0_f16[gid];
}
if (type_src1 == 1) {
global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
v1 = convert_half4(src1_f32[idx1]);
} else {
global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
v1 = src1_f16[idx1];
}
dst[gid] = v0 + v1;
}
+42
View File
@@ -0,0 +1,42 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// add_id
//------------------------------------------------------------------------------
kernel void kernel_add_id(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb02,
ulong nb11,
ulong nb21,
int ne0,
int ne1
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
src2 = (global char*)((global char*)src2 + offset2);
dst = (global char*)((global char*)dst + offsetd);
int i1 = get_group_id(0);
int i2 = get_group_id(1);
const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
global float * dst_row = (global float *)((global char *)dst + i1*nb1 + i2*nb2);
global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);
global float * src1_row = (global float *)((global char *)src1 + i11*nb11);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
+41
View File
@@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16(
}
}
//------------------------------------------------------------------------------
// swiglu_oai
//------------------------------------------------------------------------------
kernel void kernel_swiglu_oai(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off,
float limit,
float alpha
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, limit);
x1 = max(min(x1, limit), -limit);
float out_glu = x0 / (1.0f + exp(-x0 * alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = out_glu;
}
}
//------------------------------------------------------------------------------
// geglu_erf
//------------------------------------------------------------------------------
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
}
// parallel max
float4 lmax4 = -INFINITY;
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
}
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
}
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
pdst4[i00] /= sum;
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
}
// parallel max
float4 lmax4 = -INFINITY;
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
}
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
pdst4[i00] /= sum;
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
}
// parallel max
float lmax = -INFINITY;
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
pdst[i00] = exp_psrc0;
}
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
pdst[i00] /= sum;
+10 -2
View File
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
ulong offset0,
global char * src1,
ulong offset1,
global char * src2,
ulong offset2,
global char * dst,
ulong offsetd,
int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
src2 = src2 + offset2;
dst = dst + offsetd;
int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
}
// parallel max
float lmax = -INFINITY;
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
pdst[i00] = exp_psrc0;
}
const float sum = sub_group_reduce_add(lsum);
float sum = sub_group_reduce_add(lsum);
if (psrc2) {
sum += exp(psrc2[i02] - max);
}
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
pdst[i00] /= sum;
+105 -11
View File
@@ -21,6 +21,17 @@
#define UNUSED GGML_UNUSED
static inline int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
}
}
static inline int best_index_mxfp4(float x, float e) {
int best_index = 0;
float best_err = fabsf(kvalues_mxfp4[0]*e - x);
for (int i = 1; i < 16; i++) {
float err = fabsf(kvalues_mxfp4[i]*e - x);
if (err < best_err) {
best_index = i;
best_err = err;
}
}
return best_index;
}
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_MXFP4;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
for (int j = 0; j < qk; j++) {
const float v = x[i*qk + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
}
}
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
const float d = GGML_E8M0_TO_FP32_HALF(e);
y[i].e = e;
for (int j = 0; j < qk/2; ++j) {
const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
y[i].qs[j] = x0;
y[i].qs[j] |= x1 << 4;
}
}
}
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
}
}
void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_MXFP4;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
for (int j = 0; j < qk/2; ++j) {
const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
//
// 2-6 bit quantization in super-blocks
//
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * row_size;
}
size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_UNUSED(quant_weights);
quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
}
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
// ============================ 4-bit non-linear quants
static inline int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
return true;
}
static bool validate_e_e8m0(uint8_t e, size_t i) {
if (e == 0xff) {
fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
return false;
}
return true;
}
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
} \
}
#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
if (!validate_e_e8m0(q[i].e, i)) { \
return false; \
} \
}
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
} break;
case GGML_TYPE_MXFP4:
{
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
} break;
case GGML_TYPE_Q2_K:
{
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+6
View File
@@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 *
GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GG
GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);
GGML_API void iq3xs_init_impl(int grid_size);
+4 -4
View File
@@ -823,10 +823,10 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
};
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(),
/* .interface = */ ggml_backend_rpc_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint),
/* .context = */ ctx
/* .guid = */ ggml_backend_rpc_guid(),
/* .iface = */ ggml_backend_rpc_interface,
/* .device = */ ggml_backend_rpc_add_device(endpoint),
/* .context = */ ctx
};
return backend;
}
+17 -15
View File
@@ -2609,6 +2609,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src1->ne[1] == 1);
GGML_ASSERT(src1->ne[3] == 1);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
@@ -3196,7 +3198,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
}
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) {
// KQV single-batch
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
@@ -4191,15 +4193,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
struct ggml_tensor * a;
struct ggml_tensor * b;
if (op->op == GGML_OP_MUL_MAT) {
a = op->src[0];
b = op->src[1];
} else {
a = op->src[2];
b = op->src[1];
}
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
if (a->ne[3] != b->ne[3]) {
return false;
}
@@ -4214,7 +4210,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
}
}
ggml_type src0_type = op->src[0]->type;
if (src0_type == GGML_TYPE_BF16) {
if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
// TODO: support MXFP4
// FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
return true;
@@ -4359,6 +4357,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[3] != 1) {
return false;
}
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -4584,10 +4586,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
};
ggml_backend_t sycl_backend = new ggml_backend {
/* .guid = */ ggml_backend_sycl_guid(),
/* .interface = */ ggml_backend_sycl_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
/* .context = */ ctx
/* .guid = */ ggml_backend_sycl_guid(),
/* .iface = */ ggml_backend_sycl_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
/* .context = */ ctx
};
return sycl_backend;
+177 -30
View File
@@ -449,6 +449,8 @@ struct vk_device_struct {
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];
vk_pipeline pipeline_add_id_f32;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
vk_pipeline pipeline_scale_f32;
@@ -483,6 +485,7 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
vk_pipeline pipeline_swiglu[2];
vk_pipeline pipeline_swiglu_oai[2];
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
@@ -531,6 +534,7 @@ struct vk_device_struct {
ggml_backend_buffer_type buffer_type;
bool disable_fusion;
bool disable_host_visible_vidmem;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
@@ -705,6 +709,8 @@ struct vk_op_glu_push_constants {
uint32_t ne00;
uint32_t ne20;
uint32_t mode; // 0: default, 1: swapped, 2: split
float alpha; // for swiglu_oai
float limit;
};
struct vk_op_unary_push_constants {
@@ -794,6 +800,15 @@ struct vk_op_binary_push_constants {
float param1; float param2; int32_t param3;
};
struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
uint32_t s01;
uint32_t s02;
uint32_t s11;
uint32_t s21;
};
struct vk_op_diag_mask_push_constants {
uint32_t ncols;
uint32_t rows_per_channel;
@@ -835,6 +850,7 @@ struct vk_op_soft_max_push_constants {
float m1;
uint32_t n_head_log2;
uint32_t nrows_x;
uint32_t has_sinks;
};
struct vk_op_argsort_push_constants {
@@ -1789,6 +1805,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
} else if (device->uma) {
// Fall back to host memory type
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
} else if (device->disable_host_visible_vidmem) {
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal);
} else {
// use rebar if available, otherwise fallback to device only visible memory
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -1977,6 +1995,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
break;
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_MXFP4:
lut_size = 4*16;
break;
default:
@@ -2267,14 +2286,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
};
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
@@ -2353,6 +2372,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
@@ -2379,6 +2399,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -2440,6 +2461,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2461,6 +2483,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2493,6 +2516,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2514,6 +2538,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
#undef CREATE_MM
@@ -2581,6 +2606,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2618,6 +2644,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
@@ -2672,6 +2699,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
@@ -2709,6 +2737,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
@@ -2767,6 +2796,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2790,6 +2820,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2814,6 +2845,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2836,6 +2868,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2855,6 +2888,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2873,9 +2907,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -2976,6 +3011,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_BINARY(div, _norepeat, {1})
#undef CREATE_BINARY
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3026,6 +3063,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_GLU(geglu)
CREATE_GLU(reglu)
CREATE_GLU(swiglu)
CREATE_GLU(swiglu_oai)
CREATE_GLU(geglu_erf)
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
@@ -3035,10 +3073,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -3230,6 +3268,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
bool fp16_storage = false;
bool fp16_compute = false;
bool maintenance4_support = false;
@@ -4244,6 +4285,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4314,6 +4356,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4357,6 +4400,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4411,6 +4455,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4446,6 +4491,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
@@ -4631,6 +4677,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -6460,11 +6507,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
return supported;
}
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
if (sinks) {
std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
}
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
@@ -6663,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;
bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
@@ -6681,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
M_uma = d_M != nullptr;
}
if (sinks) {
ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
S_uma = d_S != nullptr;
}
}
@@ -6716,7 +6770,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
if (!S_uma) {
d_S = d_Q;
s_buf_offset = q_buf_offset;
if (sinks) {
ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
d_S = s_buf_ctx->dev_buffer;
s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
}
}
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -6740,6 +6804,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
},
// We only use split_k when group query attention is enabled, which means
@@ -6749,10 +6814,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
@@ -6763,6 +6829,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc, { workgroups_x, workgroups_y, workgroups_z });
@@ -6847,6 +6914,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
break;
}
return nullptr;
case GGML_OP_ADD_ID:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_add_id_f32;
}
return nullptr;
case GGML_OP_CONCAT:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_concat_f32;
@@ -6992,6 +7064,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU:
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU_OAI:
return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_ERF:
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_QUICK:
@@ -7007,6 +7081,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -7177,6 +7252,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SQR:
@@ -7523,6 +7599,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { ne, 1, 1 };
}
} break;
case GGML_OP_ADD_ID:
{
elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
} break;
case GGML_OP_SET_ROWS:
{
uint32_t ne = ggml_nelements(src0);
@@ -7562,8 +7642,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
// Empty src1 is possible in soft_max, but the shader needs a buffer
if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ -7573,6 +7653,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_SOFT_MAX) {
// Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
} else {
subbuf_y = { d_X, 0, x_sz };
}
vk_subbuffer subbuf_z;
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
subbuf_z = { d_X, 0, x_sz };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf_z;
@@ -7701,6 +7799,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}
static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t src2_type_size = ggml_type_size(src2->type);
ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
(uint32_t)dst->ne[0],
(uint32_t)dst->ne[1],
(uint32_t)src0->nb[1] / src0_type_size,
(uint32_t)src0->nb[2] / src0_type_size,
(uint32_t)src1->nb[1] / src1_type_size,
(uint32_t)src2->nb[1] / src2_type_size,
}, dryrun);
}
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
@@ -8119,8 +8232,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const float * op_params_f = (const float *)dst->op_params;
const bool swapped = (bool)dst->op_params[1];
const bool split = src1 != nullptr;
const float alpha = op_params_f[2];
const float limit = op_params_f[3];
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -8134,7 +8251,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
{
(uint32_t)ggml_nelements(dst),
(uint32_t)src0->ne[0],
(uint32_t)dst->ne[0],
mode,
alpha,
limit
}, dryrun);
}
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -8142,7 +8267,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
}
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
float scale = op_params[0];
@@ -8164,7 +8289,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -8174,6 +8299,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
m0, m1,
n_head_log2,
nrows_x,
src2 != nullptr
}, dryrun);
}
@@ -9413,6 +9539,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
break;
@@ -9424,6 +9551,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -9578,6 +9706,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_DIV:
ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_ADD_ID:
ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
break;
case GGML_OP_CONCAT:
ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9675,6 +9807,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9688,7 +9821,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SOFT_MAX:
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
break;
case GGML_OP_SOFT_MAX_BACK:
@@ -9761,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);
break;
@@ -9834,6 +9967,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
@@ -9903,6 +10037,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
buf = tensor->buffer;
@@ -10632,10 +10767,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
ggml_vk_init(ctx, dev_num);
ggml_backend_t vk_backend = new ggml_backend {
/* .guid = */ ggml_backend_vk_guid(),
/* .interface = */ ggml_backend_vk_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
/* .context = */ ctx,
/* .guid = */ ggml_backend_vk_guid(),
/* .iface = */ ggml_backend_vk_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
/* .context = */ ctx,
};
return vk_backend;
@@ -10752,6 +10887,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
@@ -10797,6 +10933,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
break;
default:
return false;
@@ -10834,6 +10971,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
return false;
}
if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}
@@ -10906,6 +11046,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
return true;
default:
return false;
@@ -11004,6 +11145,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SQR:
@@ -11422,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
if (src_clone[4]) {
ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
}
} else if (tensor->op == GGML_OP_MUL_MAT) {
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL_MAT_ID) {
@@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.comp"
layout (push_constant) uniform parameter
{
uint ne0;
uint ne1;
uint s01;
uint s02;
uint s11;
uint s21;
} p;
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) readonly buffer Z {int32_t data_c[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i1 = gl_WorkGroupID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i11 = data_c[i1 + i2 * p.s21];
const uint s1 = p.ne0;
const uint s2 = p.ne0 * p.ne1;
const uint d0 = i1 * s1 + i2 * s2;
const uint a0 = i1 * p.s01 + i2 * p.s02;
const uint b0 = i11 * p.s11;
for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
}
}
@@ -4,8 +4,8 @@
#include "generic_unary_head.comp"
#include "dequant_funcs.comp"
#if defined(DATA_A_IQ4_NL)
// 16 invocations needed for init_iq4nl_shmem
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
@@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
vec2 v1 = dequantize(ib, iqs + 1, a_offset);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
@@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
@@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
}
#endif
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
};
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
return ret;
}
#endif
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
@@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#endif
@@ -0,0 +1,32 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
}
}
@@ -305,6 +305,27 @@ void main() {
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
@@ -50,10 +50,13 @@ layout (push_constant) uniform parameter {
uint32_t k_num;
} p;
#define SINK_ENABLE_BIT (1<<24)
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
@@ -111,6 +114,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// Load the sink value, indexed by Q's dimension 2.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
return ACC_TYPE(data_s[h]);
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
@@ -329,6 +329,27 @@ void main() {
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
@@ -248,6 +248,34 @@ void main() {
// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
// resize M by using smear/reduce
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
// O, Ldiag, Mr all have the same type so all element locations match
[[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
ACC_TYPE sink = S[i];
ACC_TYPE ms = ACC_TYPE(1.0f);
ACC_TYPE vs = ACC_TYPE(1.0f);
if (sink > Mr[i]) {
ms = exp(Mr[i] - sink);
O[i] *= ms;
} else {
vs = exp(sink - Mr[i]);
}
Ldiag[i] = Ldiag[i]*ms + vs;
}
}
[[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
@@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
layout (binding = 1) readonly buffer B {float data_s[];};
layout (binding = 2) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne3;
uint k_num;
uint sinks;
} p;
shared float tmpsh[BLOCK_SIZE];
@@ -73,6 +75,22 @@ void main() {
}
L = tmpsh[0];
float sink;
if (p.sinks != 0) {
sink = data_s[n];
float ms = 1.0f;
float vs = 1.0f;
if (sink > m_max) {
ms = exp(m_max - sink);
} else {
vs = exp(sink - m_max);
}
L = L*ms + vs;
}
L = 1.0 / L;
// D dimension is split across workgroups in the y dimension
@@ -85,6 +103,13 @@ void main() {
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
if (p.sinks != 0) {
if (sink > m_max) {
float ms = 1.0f;
ms = exp(m_max - sink);
O *= ms;
}
}
O *= L;
data_d[iq3 * D * N + D * n + d] = O;
}
@@ -14,4 +14,6 @@ layout (push_constant) uniform parameter
uint ne00;
uint ne20;
uint mode;
float alpha;
float limit;
} p;
@@ -747,6 +747,21 @@ void main() {
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
const float d = e8m0_to_fp32(data_a[ib].e);
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
@@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) {
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
@@ -20,6 +20,7 @@ layout (push_constant) uniform parameter
float m1;
uint n_head_log2;
uint nrows_x;
uint has_sinks;
} p;
#include "types.comp"
@@ -29,7 +30,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
layout (binding = 2) readonly buffer Z {float data_c[];};
layout (binding = 3) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
@@ -60,13 +62,13 @@ void soft_max(uint num_iters) {
const uint h = (rowx / p.ne01) % p.ne02; // head index
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
slope = pow(base, exp);
}
// Find max
FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
// Cache values while we compute the max, so we don't need to read them
// again when we're ready to compute exp(x-max).
@@ -148,6 +150,10 @@ void soft_max(uint num_iters) {
}
sum = vals[0];
if (p.has_sinks != 0) {
sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
}
FLOAT_TYPE rcpdivisor = 1.0/sum;
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
@@ -0,0 +1,14 @@
#version 450
#include "glu_head.comp"
float op(float a, float b) {
float xi = min(a, p.limit);
float gi = max(min(b, p.limit), -p.limit);
float out_glu = xi / (1.0f + exp(-xi * p.alpha));
out_glu = out_glu * (1.0f + gi);
return out_glu;
}
#include "glu_main.comp"

Some files were not shown because too many files have changed in this diff Show More