mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-28 00:27:39 +02:00
Compare commits
113 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 25f1045f07 | |||
| 97669e4073 | |||
| 2f853687b3 | |||
| ef2af57ddf | |||
| 5d804a4938 | |||
| d4d8dbe383 | |||
| 35a42edac8 | |||
| fec7911f8f | |||
| 078ce23ea7 | |||
| a0c2b207c5 | |||
| 4b20d8b7e3 | |||
| 02c1813517 | |||
| 77dee9de97 | |||
| 4795c91c32 | |||
| b66df9d9c9 | |||
| b9382c3877 | |||
| 3dc7397a27 | |||
| e92d53b29e | |||
| 0d161f021a | |||
| 4efd5a8316 | |||
| 274966226f | |||
| 9777032dcc | |||
| 7d3c9f2b21 | |||
| bbbf5ecccb | |||
| c37052ab4d | |||
| 5c16b9c87d | |||
| b97c9edc59 | |||
| 94e82c7ead | |||
| 4d74393bcc | |||
| dd892555b0 | |||
| e81b8e4b7f | |||
| 38ad381f9f | |||
| 696fccf354 | |||
| ef476916bb | |||
| d82f6aa34a | |||
| 3d16b29c3b | |||
| 792b44f2ed | |||
| 81017865ee | |||
| 60e5eee31f | |||
| 009b709d6e | |||
| e8d99dd0b6 | |||
| a8bca68f72 | |||
| c97dc09391 | |||
| 6c442f42ff | |||
| 73804145ab | |||
| c8d0d14e77 | |||
| 84ab83cc0b | |||
| 55042b3692 | |||
| 8a4280ce43 | |||
| 64387f6e95 | |||
| d35a1e8c41 | |||
| 46d9caa27a | |||
| 5a0e3ef6f0 | |||
| fbef0fad7a | |||
| da54f9f1a2 | |||
| 47373271f9 | |||
| 1bded5a3b3 | |||
| 1e7489745a | |||
| 1cf123a343 | |||
| fcca2182a1 | |||
| 86076f92de | |||
| bcbddcd54f | |||
| 8b69686136 | |||
| 8ce3ff1d91 | |||
| 44b1efa41a | |||
| a6a58d6478 | |||
| 0373486dbc | |||
| 62cef26ac5 | |||
| 8f5afa94c4 | |||
| b3964c1e89 | |||
| 79a546220c | |||
| 85cc1ae998 | |||
| 1d8d83deaa | |||
| c4e9239064 | |||
| 39842a7f73 | |||
| 0fd90db585 | |||
| 4c37636b3e | |||
| 34bdbbd7c2 | |||
| 74f52f77f2 | |||
| f7207b0415 | |||
| 4d917cd4f6 | |||
| 886b97a5d6 | |||
| 111f8d06f0 | |||
| 5eff6ec9b1 | |||
| dfd9b5f6c7 | |||
| 5a6bc6b1a6 | |||
| 6b64f74b55 | |||
| 0d5a470223 | |||
| b0ba31f525 | |||
| 7da9fed0d6 | |||
| c247d06f38 | |||
| 043fb27d38 | |||
| b730706a49 | |||
| c9a24fb932 | |||
| a9c6ffcbfa | |||
| e78cf0d4b1 | |||
| 710dfc465a | |||
| 611f419cff | |||
| b1afcab804 | |||
| 9ef536907d | |||
| 21dc4ddaf2 | |||
| 289bf4113e | |||
| b55f06e1aa | |||
| 0a9b43e507 | |||
| 330c3d2d21 | |||
| e92734d51b | |||
| 45363632cb | |||
| 32732f2459 | |||
| 92f7f0a53c | |||
| b1ab91821f | |||
| 9ebebef62f | |||
| ad5c975c2d | |||
| 4afb0a746f |
@@ -2,14 +2,30 @@ ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget
|
||||
# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html
|
||||
|
||||
# Install Vulkan SDK and cURL
|
||||
RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
|
||||
apt update -y && \
|
||||
apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
|
||||
# Install build tools
|
||||
RUN apt update && apt install -y git build-essential cmake wget xz-utils
|
||||
|
||||
# Install Vulkan SDK
|
||||
ARG VULKAN_VERSION=1.4.321.1
|
||||
RUN ARCH=$(uname -m) && \
|
||||
wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \
|
||||
mkdir -p /opt/vulkan && \
|
||||
tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \
|
||||
mv /tmp/${ARCH}/* /opt/vulkan/ && \
|
||||
rm -rf /tmp/*
|
||||
|
||||
# Install cURL and Vulkan SDK dependencies
|
||||
RUN apt install -y libcurl4-openssl-dev curl \
|
||||
libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev
|
||||
|
||||
# Set environment variables
|
||||
ENV VULKAN_SDK=/opt/vulkan
|
||||
ENV PATH=$VULKAN_SDK/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH
|
||||
ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH
|
||||
ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH
|
||||
|
||||
# Build it
|
||||
WORKDIR /app
|
||||
|
||||
@@ -137,6 +137,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
|
||||
#### Multimodal
|
||||
|
||||
|
||||
@@ -386,10 +386,10 @@ function gg_run_open_llama_7b_v2 {
|
||||
|
||||
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
|
||||
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
|
||||
function check_ppl {
|
||||
qnt="$1"
|
||||
@@ -520,8 +520,8 @@ function gg_run_pythia_1_4b {
|
||||
|
||||
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
|
||||
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
|
||||
function check_ppl {
|
||||
qnt="$1"
|
||||
@@ -651,10 +651,10 @@ function gg_run_pythia_2_8b {
|
||||
|
||||
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
|
||||
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
|
||||
|
||||
function check_ppl {
|
||||
qnt="$1"
|
||||
|
||||
+40
-28
@@ -1106,7 +1106,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
|
||||
printf("\"\n\n");
|
||||
|
||||
printf(" case \"$prev\" in\n");
|
||||
printf(" --model)\n");
|
||||
printf(" --model|-m)\n");
|
||||
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
|
||||
printf(" return 0\n");
|
||||
printf(" ;;\n");
|
||||
@@ -1545,10 +1545,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
|
||||
add_opt(common_arg(
|
||||
{"-fa", "--flash-attn"},
|
||||
string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.flash_attn = true;
|
||||
{"-fa", "--flash-attn"}, "FA",
|
||||
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "on" || value == "enabled") {
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
|
||||
} else if (value == "off" || value == "disabled") {
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
} else if (value == "auto") {
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
|
||||
} else {
|
||||
throw std::runtime_error(string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_FLASH_ATTN"));
|
||||
add_opt(common_arg(
|
||||
@@ -2254,9 +2262,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
|
||||
add_opt(common_arg(
|
||||
{"-dt", "--defrag-thold"}, "N",
|
||||
string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
|
||||
string_format("KV cache defragmentation threshold (DEPRECATED)"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.defrag_thold = std::stof(value);
|
||||
GGML_UNUSED(params);
|
||||
GGML_UNUSED(value);
|
||||
LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
|
||||
}
|
||||
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
|
||||
add_opt(common_arg(
|
||||
@@ -2553,7 +2563,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--lora"}, "FNAME",
|
||||
"path to LoRA adapter (can be repeated to use multiple adapters)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
|
||||
params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
|
||||
}
|
||||
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
|
||||
@@ -2561,7 +2571,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--lora-scaled"}, "FNAME", "SCALE",
|
||||
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
|
||||
[](common_params & params, const std::string & fname, const std::string & scale) {
|
||||
params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
|
||||
params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
|
||||
}
|
||||
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
|
||||
@@ -2952,13 +2962,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.endpoint_metrics = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
|
||||
add_opt(common_arg(
|
||||
{"--slots"},
|
||||
string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.endpoint_slots = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
|
||||
add_opt(common_arg(
|
||||
{"--props"},
|
||||
string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"),
|
||||
@@ -2966,6 +2969,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.endpoint_props = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS"));
|
||||
add_opt(common_arg(
|
||||
{"--slots"},
|
||||
string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.endpoint_slots = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
|
||||
add_opt(common_arg(
|
||||
{"--no-slots"},
|
||||
"disables slots monitoring endpoint",
|
||||
@@ -3457,8 +3467,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
@@ -3473,8 +3481,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
@@ -3489,8 +3495,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
@@ -3506,10 +3510,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.n_gpu_layers = 99;
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
@@ -3525,10 +3526,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.n_gpu_layers = 99;
|
||||
params.port = 8012;
|
||||
params.n_gpu_layers = 99;
|
||||
params.flash_attn = true;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
params.n_cache_reuse = 256;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--fim-qwen-30b-default"},
|
||||
string_format("use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
params.n_ctx = 0;
|
||||
|
||||
+173
-2
@@ -622,6 +622,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
@@ -1361,6 +1362,26 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
"<|end|>",
|
||||
};
|
||||
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
|
||||
auto not_end = builder.add_rule("not-end",
|
||||
"[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]");
|
||||
auto analysis = builder.add_rule("analysis",
|
||||
"\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\"");
|
||||
auto constraint = builder.add_rule("constraint", "\"<|constrain|>\"? [a-zA-Z0-9_-]+");
|
||||
auto final = builder.add_rule("final",
|
||||
"\"<|channel|>final\" ( \" \" " + constraint + " )? \"<|message|>\" " +
|
||||
builder.add_schema("response", schema)
|
||||
);
|
||||
|
||||
builder.add_rule("root", "( " + analysis + " \"<|start|>assistant\" )? " + final);
|
||||
});
|
||||
}
|
||||
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
@@ -2039,6 +2060,94 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
}
|
||||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags first - this handles the main reasoning content
|
||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse tool calls - Seed-OSS uses <seed:tool_call> format
|
||||
static const common_regex tool_call_begin_regex("<seed:tool_call>");
|
||||
static const common_regex tool_call_end_regex("</seed:tool_call>");
|
||||
static const common_regex function_regex("<function=([^>]+)>");
|
||||
static const common_regex param_regex("<parameter=([^>]+)>");
|
||||
|
||||
while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
|
||||
builder.consume_spaces(); // Consume whitespace after <seed:tool_call>
|
||||
|
||||
// Look for function call inside tool call, ignore any content before it
|
||||
if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
|
||||
auto function_name = builder.str(func_res->groups[1]);
|
||||
|
||||
// Parse Seed-OSS parameters <parameter=name>value</parameter>
|
||||
json args = json::object();
|
||||
// Parse all parameters
|
||||
while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
|
||||
// again, ignore noise around parameters
|
||||
auto param_name = builder.str(param_res->groups[1]);
|
||||
builder.move_to(param_res->groups[0].end);
|
||||
builder.consume_spaces(); // Consume whitespace after parameter
|
||||
auto savedPos = builder.pos();
|
||||
if (auto param_parse = builder.try_find_literal("</parameter>")) {
|
||||
auto param = param_parse->prelude;
|
||||
builder.move_to(savedPos);
|
||||
try {
|
||||
if (auto param_res = builder.try_consume_json()) {
|
||||
args[param_name] = param_res->json;
|
||||
} else {
|
||||
args[param_name] = param;
|
||||
}
|
||||
} catch (json::exception &) {
|
||||
args[param_name] = param;
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool parameter");
|
||||
}
|
||||
}
|
||||
// Look for closing function tag
|
||||
auto end_func = builder.try_find_literal("</function>");
|
||||
if (end_func) {
|
||||
builder.move_to(end_func->groups[0].end);
|
||||
builder.consume_spaces(); // Consume whitespace after </function>
|
||||
|
||||
// Add the tool call with parsed arguments, but only if we REALLY got the literal
|
||||
auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
|
||||
auto funlen = std::string("</function>").length();
|
||||
if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("</function>")) {
|
||||
if (!builder.add_tool_call(function_name, "", args.dump())) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
// Look for closing tool call tag
|
||||
if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
|
||||
builder.move_to(end_tool->groups[0].end);
|
||||
builder.consume_spaces(); // Consume trailing whitespace after tool call
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
// No function found - don't consume content here, let it be handled at the end
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Consume any remaining whitespace after all tool call processing
|
||||
builder.consume_spaces();
|
||||
auto remaining = builder.consume_rest();
|
||||
// If there's any non-whitespace content remaining, add it as content
|
||||
if (!string_strip(remaining).empty()) {
|
||||
builder.add_content(remaining);
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
@@ -2055,8 +2164,62 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_seed_oss(
|
||||
const common_chat_template & tmpl,
|
||||
templates_params & params,
|
||||
const common_chat_templates_inputs & inputs)
|
||||
{
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_SEED_OSS;
|
||||
if (string_ends_with(data.prompt, "<seed:think>")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
data.prompt += "</seed:think>";
|
||||
} else {
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.tools.is_array() && !params.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
// Create rule for Seed-OSS function call format
|
||||
std::string param_rules;
|
||||
if (parameters.contains("properties")) {
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
param_rules += "\"<parameter=" + key + ">\"" + builder.add_schema(name + "-arg-" + key, value) +
|
||||
"\"</parameter>\"";
|
||||
}
|
||||
}
|
||||
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
"\"<seed:tool_call>\" space \"<function=" + name + ">\" space " +
|
||||
param_rules +
|
||||
" \"</function>\" space \"</seed:tool_call>\""));
|
||||
});
|
||||
|
||||
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<seed:tool_call>" });
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<seed:think>", "</seed:think>", "<seed:tool_call>", "</seed:tool_call>",
|
||||
"<function=", "</function>", "<parameter=", "</parameter>",
|
||||
};
|
||||
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
});
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_templates_apply_jinja(
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
templates_params params;
|
||||
@@ -2121,10 +2284,15 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
}
|
||||
|
||||
// GPT-OSS
|
||||
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
|
||||
if (src.find("<|channel|>") != std::string::npos) {
|
||||
return common_chat_params_init_gpt_oss(tmpl, params);
|
||||
}
|
||||
|
||||
// Seed-OSS
|
||||
if (src.find("<seed:think>") != std::string::npos) {
|
||||
return common_chat_params_init_seed_oss(tmpl, params, inputs);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
@@ -2283,6 +2451,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS:
|
||||
common_chat_parse_gpt_oss(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_SEED_OSS:
|
||||
common_chat_parse_seed_oss(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_GRANITE,
|
||||
COMMON_CHAT_FORMAT_GPT_OSS,
|
||||
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
+10
-4
@@ -901,7 +901,8 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return iparams;
|
||||
}
|
||||
|
||||
@@ -911,7 +912,8 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
}
|
||||
@@ -988,7 +990,12 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
return iparams;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
la.ptr = lora.get();
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
|
||||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
@@ -1152,11 +1159,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.attention_type = params.attention_type;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.flash_attn_type = params.flash_attn_type;
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
cparams.offload_kqv = !params.no_kv_offload;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
cparams.swa_full = params.swa_full;
|
||||
|
||||
+5
-3
@@ -34,6 +34,9 @@ struct common_adapter_lora_info {
|
||||
std::string path;
|
||||
float scale;
|
||||
|
||||
std::string task_name;
|
||||
std::string prompt_prefix;
|
||||
|
||||
struct llama_adapter_lora * ptr;
|
||||
};
|
||||
|
||||
@@ -288,7 +291,6 @@ struct common_params {
|
||||
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||
float defrag_thold = 0.1f; // KV cache defragmentation threshold
|
||||
|
||||
// offload params
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
@@ -310,6 +312,7 @@ struct common_params {
|
||||
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
|
||||
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
|
||||
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
|
||||
|
||||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
@@ -373,7 +376,6 @@ struct common_params {
|
||||
bool multiline_input = false; // reverse the usage of `\`
|
||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||
bool flash_attn = false; // flash attention
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
@@ -442,7 +444,7 @@ struct common_params {
|
||||
|
||||
// "advanced" endpoints are disabled by default for better security
|
||||
bool webui = true;
|
||||
bool endpoint_slots = false;
|
||||
bool endpoint_slots = true;
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
|
||||
+23
-2
@@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
|
||||
// helpers
|
||||
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
|
||||
return &gsmpl->cur_p;
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
||||
auto * res = &gsmpl->cur_p;
|
||||
|
||||
if (do_sort && !res->sorted) {
|
||||
// remember the selected token before sorting
|
||||
const llama_token id = res->data[res->selected].id;
|
||||
|
||||
std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.p > b.p;
|
||||
});
|
||||
|
||||
// restore the selected token after sorting
|
||||
for (size_t i = 0; i < res->size; ++i) {
|
||||
if (res->data[i].id == id) {
|
||||
res->selected = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
res->sorted = true;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
|
||||
|
||||
+3
-1
@@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||
// helpers
|
||||
|
||||
// access the internal list of current candidate tokens
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
|
||||
// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
|
||||
// the .sorted flag of the result indicates whether the returned candidates are sorted
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
|
||||
|
||||
// get the last accepted token
|
||||
llama_token common_sampler_last(const struct common_sampler * gsmpl);
|
||||
|
||||
@@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, 0, true);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl);
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
|
||||
+248
-83
@@ -72,6 +72,7 @@ class ModelBase:
|
||||
endianess: gguf.GGUFEndian
|
||||
use_temp_file: bool
|
||||
lazy: bool
|
||||
dry_run: bool
|
||||
part_names: list[str]
|
||||
is_safetensors: bool
|
||||
hparams: dict[str, Any]
|
||||
@@ -111,6 +112,7 @@ class ModelBase:
|
||||
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||
self.use_temp_file = use_temp_file
|
||||
self.lazy = not eager or (remote_hf_model_id is not None)
|
||||
self.dry_run = dry_run
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
if remote_hf_model_id is not None:
|
||||
self.is_safetensors = True
|
||||
@@ -300,10 +302,6 @@ class ModelBase:
|
||||
# data = data_torch.squeeze().numpy()
|
||||
data = data_torch.numpy()
|
||||
|
||||
# if data ends up empty, it means data_torch was a scalar tensor -> restore
|
||||
if len(data.shape) == 0:
|
||||
data = data_torch.numpy()
|
||||
|
||||
n_dims = len(data.shape)
|
||||
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
@@ -1216,6 +1214,55 @@ class TextModel(ModelBase):
|
||||
raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
else:
|
||||
token: str = reverse_vocab[i]
|
||||
if token in added_vocab:
|
||||
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
tokens.append(token)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_vocab._set_special_token("bos", 151643)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
class MmprojModel(ModelBase):
|
||||
model_type = ModelType.MMPROJ
|
||||
@@ -2932,7 +2979,8 @@ class Qwen2Model(TextModel):
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "") # for InternVL
|
||||
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower"):
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower") \
|
||||
or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
|
||||
# skip vision and audio tensors
|
||||
return []
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
@@ -3109,7 +3157,7 @@ class LLaDAModel(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM")
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
|
||||
class Ernie4_5Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.ERNIE4_5
|
||||
|
||||
@@ -3604,6 +3652,19 @@ class Qwen2MoeModel(TextModel):
|
||||
class Qwen3Model(Qwen2Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
def set_vocab(self):
|
||||
# deal with intern-s1-mini
|
||||
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
|
||||
self._set_vocab_interns1()
|
||||
return
|
||||
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3MoeForCausalLM")
|
||||
class Qwen3MoeModel(Qwen2MoeModel):
|
||||
@@ -3620,73 +3681,7 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
||||
self._set_vocab_interns1()
|
||||
return
|
||||
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def _set_vocab_interns1(self):
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
|
||||
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
|
||||
vocab_size = self.hparams.get("vocab_size", len(vocab))
|
||||
assert max(vocab.values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
|
||||
added_tokens_decoder = tokenizer.added_tokens_decoder
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
else:
|
||||
token: str = reverse_vocab[i]
|
||||
if token in added_vocab:
|
||||
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
|
||||
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
|
||||
if not added_tokens_decoder[i].normalized:
|
||||
previous_token = token
|
||||
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
|
||||
if previous_token != token:
|
||||
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
|
||||
|
||||
if added_tokens_decoder[i].special or self.does_token_look_special(token):
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
tokens.append(token)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
|
||||
additional_special_tokens = []
|
||||
if special_tokens_map_file.is_file():
|
||||
with open(special_tokens_map_file, encoding = 'utf-8') as f:
|
||||
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
|
||||
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
|
||||
if tokenizer_cfg_file.is_file():
|
||||
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
|
||||
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
|
||||
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
|
||||
for token in additional_special_tokens:
|
||||
if token in token2ids_map:
|
||||
special_vocab._set_special_token(token, token2ids_map[token])
|
||||
special_vocab._set_special_token('eos', 151645)
|
||||
special_vocab._set_special_token("bos", 151643)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
super().set_vocab()
|
||||
|
||||
|
||||
@ModelBase.register("GPT2LMHeadModel")
|
||||
@@ -4874,11 +4869,35 @@ class NeoBert(BertModel):
|
||||
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
|
||||
class XLMRobertaModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
_lora_files = {}
|
||||
_lora_names = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
|
||||
hparams = kwargs.pop("hparams", None)
|
||||
if hparams is None:
|
||||
hparams = ModelBase.load_hparams(dir_model, False)
|
||||
|
||||
if lora_names := hparams.get("lora_adaptations"):
|
||||
self._lora_names = lora_names
|
||||
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
|
||||
|
||||
super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
|
||||
self._xlmroberta_tokenizer_init()
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
if self._lora_names:
|
||||
for name in self._lora_names:
|
||||
fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
|
||||
self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)
|
||||
|
||||
return super().generate_extra_tensors()
|
||||
|
||||
def set_type(self):
|
||||
for lora_writer in self._lora_files.values():
|
||||
lora_writer.add_type(gguf.GGUFType.ADAPTER)
|
||||
lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
|
||||
super().set_type()
|
||||
|
||||
def set_vocab(self):
|
||||
self._xlmroberta_set_vocab()
|
||||
|
||||
@@ -4888,13 +4907,62 @@ class XLMRobertaModel(BertModel):
|
||||
if name.startswith("roberta."):
|
||||
name = name[8:]
|
||||
|
||||
# jina-embeddings-v3
|
||||
if ".parametrizations." in name:
|
||||
name = name.replace(".parametrizations.", ".")
|
||||
if name.endswith(".original"):
|
||||
name = name[:-9]
|
||||
|
||||
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
|
||||
if name == "embeddings.position_embeddings.weight":
|
||||
if self._position_offset is not None:
|
||||
data_torch = data_torch[self._position_offset:,:]
|
||||
|
||||
if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
|
||||
if name.startswith("pooler.dense"):
|
||||
return []
|
||||
|
||||
num_loras = data_torch.size(0)
|
||||
assert num_loras == len(self._lora_names)
|
||||
|
||||
# Split out each LoRA in their own GGUF
|
||||
for i, lora_writer in enumerate(self._lora_files.values()):
|
||||
new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
|
||||
data = data_torch[i, :, :]
|
||||
# Transpose/flip token_embd/types into correct shape
|
||||
if new_name == "token_embd.weight.lora_b":
|
||||
data = data.T
|
||||
elif new_name.startswith("token_types.weight."):
|
||||
new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
|
||||
lora_writer.add_tensor(new_name, data.float().numpy(), raw_dtype=gguf.GGMLQuantizationType.F32)
|
||||
|
||||
return []
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# jina-embeddings-v3
|
||||
if rotary_emb_base := self.hparams.get("rotary_emb_base"):
|
||||
self.gguf_writer.add_rope_freq_base(rotary_emb_base)
|
||||
lora_alpha = self.hparams.get("lora_alpha")
|
||||
if lora_prompt_prefixes := self.hparams.get("task_instructions"):
|
||||
assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
|
||||
for lora_name, lora_writer in self._lora_files.items():
|
||||
lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
|
||||
lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
|
||||
if lora_prompt_prefixes:
|
||||
lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])
|
||||
|
||||
def write(self):
|
||||
super().write()
|
||||
for lora_writer in self._lora_files.values():
|
||||
lora_writer.write_header_to_file()
|
||||
lora_writer.write_kv_data_to_file()
|
||||
lora_writer.write_tensors_to_file(progress=True)
|
||||
lora_writer.close()
|
||||
|
||||
|
||||
@ModelBase.register("GemmaForCausalLM")
|
||||
class GemmaModel(TextModel):
|
||||
@@ -5854,6 +5922,11 @@ class OlmoModel(TextModel):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("SeedOssForCausalLM")
|
||||
class SeedOssModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.SEED_OSS
|
||||
|
||||
|
||||
@ModelBase.register("Olmo2ForCausalLM")
|
||||
class Olmo2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.OLMO2
|
||||
@@ -6252,9 +6325,11 @@ class DeepseekModel(TextModel):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekV2ForCausalLM")
|
||||
@ModelBase.register("DeepseekV3ForCausalLM")
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
@ModelBase.register(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"KimiVLForConditionalGeneration",
|
||||
)
|
||||
class DeepseekV2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
||||
|
||||
@@ -7467,9 +7542,13 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
]
|
||||
|
||||
# n_group and d_inner are used during reshape_tensors for mamba2
|
||||
self.d_model = self.find_hparam(["hidden_size", "d_model"])
|
||||
self.n_group = self.find_hparam(["n_groups"])
|
||||
self.d_inner = self.find_hparam(["expand"]) * self.d_model
|
||||
# NOTE: Explicitly include hparam prefix prefix for d_model to
|
||||
# disambiguate with top-level head_dim
|
||||
# NOTE 2: If needed for future models, this can be isolated in a method
|
||||
# to separate the prefix setting and teh keys used
|
||||
self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
|
||||
self.n_group = self.find_hparam(["n_groups", "num_groups"])
|
||||
self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
|
||||
|
||||
def get_attn_layers(self):
|
||||
# Explicit list of layer type names
|
||||
@@ -7530,12 +7609,12 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
|
||||
## Mamba mixer params ##
|
||||
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
|
||||
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
|
||||
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
|
||||
self.gguf_writer.add_ssm_group_count(self.n_group)
|
||||
self.gguf_writer.add_ssm_inner_size(self.d_inner)
|
||||
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
|
||||
# in llama.cpp
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
|
||||
|
||||
## Attention params ##
|
||||
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
@@ -7562,6 +7641,55 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
Mamba2Model.set_vocab(self)
|
||||
|
||||
|
||||
@ModelBase.register("NemotronHForCausalLM")
|
||||
class NemotronHModel(GraniteHybridModel):
|
||||
"""Hybrid mamba2/attention model from NVIDIA"""
|
||||
model_arch = gguf.MODEL_ARCH.NEMOTRON_H
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Save the top-level head_dim for later
|
||||
self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
|
||||
assert self.head_dim is not None, "Could not find the attention head dim in config"
|
||||
|
||||
# Don't use expand to calculate d_inner
|
||||
self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
|
||||
|
||||
# Update the ssm / attn / mlp layers
|
||||
# M: Mamba2, *: Attention, -: MLP
|
||||
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
|
||||
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
|
||||
|
||||
def get_attn_layers(self):
|
||||
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||
assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
|
||||
return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
self.gguf_writer.add_key_length(self.head_dim)
|
||||
self.gguf_writer.add_value_length(self.head_dim)
|
||||
|
||||
# Set feed_forward_length
|
||||
# NOTE: This will trigger an override warning. This is preferrable to
|
||||
# duplicating all the parent logic
|
||||
n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
|
||||
self.gguf_writer.add_feed_forward_length([
|
||||
n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
|
||||
])
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
|
||||
# The tokenizer _does_ add a BOS token (via post_processor type
|
||||
# TemplateProcessing) but does not set add_bos_token to true in the
|
||||
# config, so we need to explicitly override it here.
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
|
||||
|
||||
@ModelBase.register("BailingMoeForCausalLM")
|
||||
class BailingMoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE
|
||||
@@ -8505,6 +8633,43 @@ class PixtralModel(LlavaVisionModel):
|
||||
return "mm.2.weight"
|
||||
return super().map_tensor_name(name, try_suffixes)
|
||||
|
||||
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
class KimiVLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams_vision is not None
|
||||
self.hparams_vision["image_size"] = 64 * 14 # for compatibility
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
self.gguf_writer.add_vision_projector_scale_factor(2)
|
||||
# eps is the same as pytorch's default value
|
||||
assert self.hparams_vision is not None
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
|
||||
|
||||
if is_vision_tensor:
|
||||
if "pos_emb.weight" in name:
|
||||
data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
|
||||
elif "wqkv" in name:
|
||||
split_dim = 0 if "weight" in name else -1
|
||||
wq, wk, wv = data_torch.chunk(3, dim=split_dim)
|
||||
return [
|
||||
(self.map_tensor_name(name.replace("wqkv", "wq")), wq),
|
||||
(self.map_tensor_name(name.replace("wqkv", "wk")), wk),
|
||||
(self.map_tensor_name(name.replace("wqkv", "wv")), wv)
|
||||
]
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
return [] # skip other tensors
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
||||
@@ -314,3 +314,7 @@ Controls automatic cleanup of the memory pool. This option is only effective whe
|
||||
|
||||
Converting the matmul weight format from ND to NZ can significantly improve performance on the 310I DUO NPU.
|
||||
|
||||
### GGML_CANN_DISABLE_ACL_GRAPH
|
||||
|
||||
When this variable is set, ACL graph execution is disabled and operators are executed in an op-by-op (eager) mode.
|
||||
This mode is mainly intended for debugging or for cases where the overhead of graph construction and execution is not desirable.
|
||||
|
||||
+4
-3
@@ -265,8 +265,9 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
| BF16 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q4_0 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q4_1 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q5_0 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q5_1 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| MXFP4 | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q5_0 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q5_1 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q8_0 | ✅ | ✅ | ❓ | ❓ |
|
||||
| Q2_K | 🚫 | 🚫 | ❓ | ❓ |
|
||||
| Q3_K | ✅ | ✅ | ❓ | ❓ |
|
||||
@@ -291,4 +292,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
- 🚫 - acceleration unavailable, will still run using scalar implementation
|
||||
- ❓ - acceleration unknown, please contribute if you can test it yourself
|
||||
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on July 31, 2025.
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Aug 22, 2025.
|
||||
|
||||
@@ -59,8 +59,6 @@ cmake --build build --config Release
|
||||
cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF
|
||||
cmake --build build-arm64-windows-llvm-release
|
||||
```
|
||||
Building for arm64 can also be done with the MSVC compiler with the build-arm64-windows-MSVC preset, or the standard CMake build instructions. However, note that the MSVC compiler does not support inline ARM assembly code, used e.g. for the accelerated Q4_0_N_M CPU kernels.
|
||||
|
||||
For building with ninja generator and clang compiler as default:
|
||||
-set path:set LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\x64;C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\lib\x64\uwp;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\x64
|
||||
```bash
|
||||
|
||||
@@ -21,6 +21,8 @@ Function calling is supported for all models (see https://github.com/ggml-org/ll
|
||||
- Use `--chat-template-file` to override the template when appropriate (see examples below)
|
||||
- Generic support may consume more tokens and be less efficient than a model's native format.
|
||||
|
||||
- Multiple/parallel tool calling is supported on some models but disabled by default, enable it by passing `"parallel_tool_calls": true` in the completion endpoint payload.
|
||||
|
||||
<details>
|
||||
<summary>Show some common templates and which format handler they use</summary>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
|
||||
|
||||
|
||||
### Build llama.cpp
|
||||
Readme modification time: 20250206
|
||||
Readme modification time: 20250731
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
## MiniCPM-V 4.5
|
||||
|
||||
### Prepare models and code
|
||||
|
||||
Download [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch model from huggingface to "MiniCPM-V-4_5" folder.
|
||||
|
||||
|
||||
### Build llama.cpp
|
||||
Readme modification time: 20250826
|
||||
|
||||
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md)
|
||||
|
||||
Clone llama.cpp:
|
||||
```bash
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
Build llama.cpp using `CMake`:
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
|
||||
### Usage of MiniCPM-V 4
|
||||
|
||||
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-4_5-gguf) by us)
|
||||
|
||||
```bash
|
||||
python ./tools/mtmd/legacy-models/minicpmv-surgery.py -m ../MiniCPM-V-4_5
|
||||
python ./tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-4_5 --minicpmv-projector ../MiniCPM-V-4_5/minicpmv.projector --output-dir ../MiniCPM-V-4_5/ --minicpmv_version 6
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-V-4_5/model
|
||||
|
||||
# quantize int4 version
|
||||
./build/bin/llama-quantize ../MiniCPM-V-4_5/model/ggml-model-f16.gguf ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
```
|
||||
|
||||
|
||||
Inference on Linux or Mac
|
||||
```bash
|
||||
# run in single-turn mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run in conversation mode
|
||||
./build/bin/llama-mtmd-cli -m ../MiniCPM-V-4_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-4_5/mmproj-model-f16.gguf
|
||||
```
|
||||
@@ -564,7 +564,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_params.n_ctx = params.n_ctx;
|
||||
ctx_params.n_batch = params.n_batch;
|
||||
ctx_params.n_ubatch = params.n_ubatch;
|
||||
ctx_params.flash_attn = params.flash_attn;
|
||||
ctx_params.flash_attn_type = params.flash_attn_type;
|
||||
ctx_params.no_perf = params.no_perf;
|
||||
ctx_params.type_k = params.cache_type_k;
|
||||
ctx_params.type_v = params.cache_type_v;
|
||||
|
||||
@@ -28,9 +28,40 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
|
||||
return str;
|
||||
}
|
||||
|
||||
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
|
||||
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
||||
float v;
|
||||
if (type == GGML_TYPE_F16) {
|
||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
||||
} else if (type == GGML_TYPE_F32) {
|
||||
v = *(float *) &data[i];
|
||||
} else if (type == GGML_TYPE_I64) {
|
||||
v = (float) *(int64_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I32) {
|
||||
v = (float) *(int32_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I16) {
|
||||
v = (float) *(int16_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I8) {
|
||||
v = (float) *(int8_t *) &data[i];
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
|
||||
GGML_ASSERT(n > 0);
|
||||
float sum = 0;
|
||||
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
|
||||
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
|
||||
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
|
||||
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
|
||||
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
|
||||
sum += v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
|
||||
LOG(" [\n");
|
||||
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
|
||||
@@ -50,25 +81,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
|
||||
LOG("..., ");
|
||||
i0 = ne[0] - n;
|
||||
}
|
||||
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
|
||||
float v;
|
||||
if (type == GGML_TYPE_F16) {
|
||||
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
|
||||
} else if (type == GGML_TYPE_F32) {
|
||||
v = *(float *) &data[i];
|
||||
} else if (type == GGML_TYPE_I64) {
|
||||
v = (float) *(int64_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I32) {
|
||||
v = (float) *(int32_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I16) {
|
||||
v = (float) *(int16_t *) &data[i];
|
||||
} else if (type == GGML_TYPE_I8) {
|
||||
v = (float) *(int8_t *) &data[i];
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
|
||||
LOG("%12.4f", v);
|
||||
sum += v;
|
||||
if (i0 < ne[0] - 1) LOG(", ");
|
||||
}
|
||||
LOG("],\n");
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@
|
||||
"
|
||||
" start the llama.cpp server with a FIM-compatible model. for example:
|
||||
"
|
||||
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256
|
||||
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa --ubatch-size 512 --batch-size 1024 --cache-reuse 256
|
||||
"
|
||||
" --batch-size [512, model max context]
|
||||
"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Validation functions
|
||||
MAKEFLAGS += --no-print-directory
|
||||
|
||||
define validate_model_path
|
||||
@if [ -z "$(MODEL_PATH)" ]; then \
|
||||
echo "Error: MODEL_PATH must be provided either as:"; \
|
||||
@@ -17,6 +18,13 @@ define validate_embedding_model_path
|
||||
fi
|
||||
endef
|
||||
|
||||
define quantize_model
|
||||
@CONVERTED_MODEL="$(1)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" \
|
||||
TOKEN_EMBD_TYPE="$(TOKEN_EMBD_TYPE)" OUTPUT_TYPE="$(OUTPUT_TYPE)" \
|
||||
./scripts/utils/quantize.sh "$(1)" "$(QUANTIZED_TYPE)" "$(TOKEN_EMBD_TYPE)" "$(OUTPUT_TYPE)"
|
||||
@echo "Export the quantized model path to $(2) variable in your environment"
|
||||
endef
|
||||
|
||||
###
|
||||
### Casual Model targets/recipes
|
||||
###
|
||||
@@ -29,6 +37,20 @@ causal-convert-model:
|
||||
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
|
||||
./scripts/causal/convert-model.sh
|
||||
|
||||
causal-convert-mm-model-bf16: OUTTYPE=bf16
|
||||
causal-convert-mm-model-bf16: MM_OUTTYPE=f16
|
||||
causal-convert-mm-model-bf16: causal-convert-mm-model
|
||||
|
||||
causal-convert-mm-model:
|
||||
$(call validate_model_path,causal-convert-mm-model)
|
||||
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
|
||||
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
|
||||
./scripts/causal/convert-model.sh
|
||||
|
||||
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(MM_OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
|
||||
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
|
||||
./scripts/causal/convert-model.sh --mmproj
|
||||
|
||||
causal-run-original-model:
|
||||
$(call validate_model_path,causal-run-original-model)
|
||||
@MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py
|
||||
@@ -67,9 +89,15 @@ causal-quantize-Q8_0: causal-quantize-model
|
||||
causal-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
|
||||
causal-quantize-Q4_0: causal-quantize-model
|
||||
|
||||
# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
|
||||
# token embedding and output types to Q8_0 instead of the default Q6_K.
|
||||
causal-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
|
||||
causal-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
|
||||
causal-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
|
||||
causal-quantize-qat-Q4_0: causal-quantize-model
|
||||
|
||||
causal-quantize-model:
|
||||
@CONVERTED_MODEL="$(CONVERTED_MODEL)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" ./scripts/utils/quantize.sh ${CONVERTED_MODEL} ${QUANTIZED_TYPE}
|
||||
@echo "Export the quantized model path to QUANTIZED_MODEL variable in your environment"
|
||||
$(call quantize_model,$(CONVERTED_MODEL),QUANTIZED_MODEL)
|
||||
|
||||
causal-run-quantized-model:
|
||||
@QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/causal/run-converted-model.sh ${QUANTIZED_MODEL}
|
||||
@@ -117,9 +145,15 @@ embedding-quantize-Q8_0: embedding-quantize-model
|
||||
embedding-quantize-Q4_0: QUANTIZED_TYPE = Q4_0
|
||||
embedding-quantize-Q4_0: embedding-quantize-model
|
||||
|
||||
# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the
|
||||
# token embedding and output types to Q8_0 instead of the default Q6_K.
|
||||
embedding-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0
|
||||
embedding-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0
|
||||
embedding-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0
|
||||
embedding-quantize-qat-Q4_0: embedding-quantize-model
|
||||
|
||||
embedding-quantize-model:
|
||||
@./scripts/utils/quantize.sh ${CONVERTED_EMBEDDING_MODEL} ${QUANTIZED_TYPE}
|
||||
@echo "Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment"
|
||||
$(call quantize_model,$(CONVERTED_EMBEDDING_MODEL),QUANTIZED_EMBEDDING_MODEL)
|
||||
|
||||
embedding-run-quantized-model:
|
||||
@./scripts/embedding/run-converted-model.sh ${QUANTIZED_EMBEDDING_MODEL}
|
||||
@@ -144,6 +178,15 @@ perplexity-run:
|
||||
hf-create-model:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}"
|
||||
|
||||
hf-create-model-dry-run:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d
|
||||
|
||||
hf-create-model-embedding:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e
|
||||
|
||||
hf-create-model-embedding-dry-run:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d
|
||||
|
||||
hf-create-model-private:
|
||||
@./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p
|
||||
|
||||
|
||||
@@ -137,6 +137,18 @@ Then the quantized model can be run using the following command:
|
||||
(venv) $ make causal-run-quantized-model
|
||||
```
|
||||
|
||||
### Quantizing QAT (Quantization Aware Training) models
|
||||
When quantizing to `Q4_0`, the default data type for the token embedding weights
|
||||
will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
|
||||
recommended to use `Q8_0` instead for the embeddings and output tensors.
|
||||
The reason is that although `Q6_K` is smaller in size, it requires more compute
|
||||
to unpack, which can hurt performance during output generation when the entire
|
||||
embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
|
||||
provides practically full quality with better computational efficiency.
|
||||
```console
|
||||
(venv) $ make causal-quantize-qat-Q4_0
|
||||
```
|
||||
|
||||
|
||||
## Embedding Language Model Conversion
|
||||
|
||||
@@ -238,6 +250,18 @@ Then the quantized model can be run using the following command:
|
||||
(venv) $ make embedding-run-quantized-model
|
||||
```
|
||||
|
||||
### Quantizing QAT (Quantization Aware Training) models
|
||||
When quantizing to `Q4_0`, the default data type for the token embedding weights
|
||||
will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
|
||||
recommended to use `Q8_0` instead for the embeddings and output tensors.
|
||||
The reason is that although `Q6_K` is smaller in size, it requires more compute
|
||||
to unpack, which can hurt performance during output generation when the entire
|
||||
embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
|
||||
provides practically full quality with better computational efficiency.
|
||||
```console
|
||||
(venv) $ make embedding-quantize-qat-Q4_0
|
||||
```
|
||||
|
||||
## Perplexity Evaluation
|
||||
|
||||
### Simple perplexity evaluation
|
||||
@@ -285,13 +309,21 @@ For the following targets a `HF_TOKEN` environment variable is required.
|
||||
This will create a new model repsository on Hugging Face with the specified
|
||||
model name.
|
||||
```console
|
||||
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev"
|
||||
(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
|
||||
Repository ID: danbev/TestModel-GGUF
|
||||
Repository created: https://huggingface.co/danbev/TestModel-GGUF
|
||||
```
|
||||
Note that we append a `-GGUF` suffix to the model name to ensure a consistent
|
||||
naming convention for GGUF models.
|
||||
|
||||
An embedding model can be created using the following command:
|
||||
```console
|
||||
(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
|
||||
```
|
||||
The only difference is that the model card for an embedding model will be different
|
||||
with regards to the llama-server command and also how to access/call the embedding
|
||||
endpoint.
|
||||
|
||||
### Upload a GGUF model to model repository
|
||||
The following target uploads a model to an existing Hugging Face model repository.
|
||||
```console
|
||||
|
||||
@@ -112,6 +112,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_params.no_perf = false;
|
||||
if (embedding_mode) {
|
||||
ctx_params.embeddings = true;
|
||||
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
ctx_params.n_ubatch = ctx_params.n_batch;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Parse command line arguments
|
||||
MMPROJ=""
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--mmproj)
|
||||
MMPROJ="--mmproj"
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
|
||||
TYPE="${OUTTYPE:-f16}"
|
||||
@@ -11,12 +27,20 @@ echo "Model name: ${MODEL_NAME}"
|
||||
echo "Data type: ${TYPE}"
|
||||
echo "Converted model path:: ${CONVERTED_MODEL}"
|
||||
echo "Metadata override: ${METADATA_OVERRIDE}"
|
||||
python ../../convert_hf_to_gguf.py --verbose \
|
||||
${MODEL_PATH} \
|
||||
--outfile ${CONVERTED_MODEL} \
|
||||
--outtype ${TYPE} \
|
||||
--metadata "${METADATA_OVERRIDE}"
|
||||
|
||||
CMD_ARGS=("python" "../../convert_hf_to_gguf.py" "--verbose")
|
||||
CMD_ARGS+=("${MODEL_PATH}")
|
||||
CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}")
|
||||
CMD_ARGS+=("--outtype" "${TYPE}")
|
||||
[[ -n "$METADATA_OVERRIDE" ]] && CMD_ARGS+=("--metadata" "${METADATA_OVERRIDE}")
|
||||
[[ -n "$MMPROJ" ]] && CMD_ARGS+=("${MMPROJ}")
|
||||
|
||||
"${CMD_ARGS[@]}"
|
||||
|
||||
echo ""
|
||||
echo "The environment variable CONVERTED_MODEL can be set to this path using:"
|
||||
echo "export CONVERTED_MODEL=$(realpath ${CONVERTED_MODEL})"
|
||||
if [[ -n "$MMPROJ" ]]; then
|
||||
mmproj_file="${OUTPUT_DIR}/mmproj-$(basename "${CONVERTED_MODEL}")"
|
||||
echo "The mmproj model was created in $(realpath "$mmproj_file")"
|
||||
fi
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
---
|
||||
base_model:
|
||||
- {base_model}
|
||||
---
|
||||
# {model_name} GGUF
|
||||
|
||||
Recommended way to run this model:
|
||||
|
||||
```sh
|
||||
llama-server -hf {namespace}/{model_name}-GGUF
|
||||
```
|
||||
|
||||
Then the endpoint can be accessed at http://localhost:8080/embedding, for
|
||||
example using `curl`:
|
||||
```console
|
||||
curl --request POST \
|
||||
--url http://localhost:8080/embedding \
|
||||
--header "Content-Type: application/json" \
|
||||
--data '{{"input": "Hello embeddings"}}' \
|
||||
--silent
|
||||
```
|
||||
|
||||
Alternatively, the `llama-embedding` command line tool can be used:
|
||||
```sh
|
||||
llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings"
|
||||
```
|
||||
|
||||
#### embd_normalize
|
||||
When a model uses pooling, or the pooling method is specified using `--pooling`,
|
||||
the normalization can be controlled by the `embd_normalize` parameter.
|
||||
|
||||
The default value is `2` which means that the embeddings are normalized using
|
||||
the Euclidean norm (L2). Other options are:
|
||||
* -1 No normalization
|
||||
* 0 Max absolute
|
||||
* 1 Taxicab
|
||||
* 2 Euclidean/L2
|
||||
* \>2 P-Norm
|
||||
|
||||
This can be passed in the request body to `llama-server`, for example:
|
||||
```sh
|
||||
--data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \
|
||||
```
|
||||
|
||||
And for `llama-embedding`, by passing `--embd-normalize <value>`, for example:
|
||||
```sh
|
||||
llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings"
|
||||
```
|
||||
@@ -26,21 +26,31 @@ parser.add_argument('--namespace', '-ns', help='Namespace to add the model to',
|
||||
parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
|
||||
parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
|
||||
parser.add_argument('--private', '-p', action='store_true', help='Create private model')
|
||||
parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
|
||||
parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
repo_id = f"{args.namespace}/{args.model_name}-GGUF"
|
||||
print("Repository ID: ", repo_id)
|
||||
|
||||
repo_url = api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
private=args.private,
|
||||
exist_ok=False
|
||||
)
|
||||
repo_url = None
|
||||
if not args.dry_run:
|
||||
repo_url = api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
private=args.private,
|
||||
exist_ok=False
|
||||
)
|
||||
|
||||
if not args.no_card:
|
||||
template_path = "scripts/readme.md.template"
|
||||
if args.embedding:
|
||||
template_path = "scripts/embedding/modelcard.template"
|
||||
else:
|
||||
template_path = "scripts/causal/modelcard.template"
|
||||
|
||||
print("Template path: ", template_path)
|
||||
|
||||
model_card_content = load_template_and_substitute(
|
||||
template_path,
|
||||
model_name=args.model_name,
|
||||
@@ -48,16 +58,21 @@ if not args.no_card:
|
||||
base_model=args.org_base_model,
|
||||
)
|
||||
|
||||
if model_card_content:
|
||||
api.upload_file(
|
||||
path_or_fileobj=model_card_content.encode('utf-8'),
|
||||
path_in_repo="README.md",
|
||||
repo_id=repo_id
|
||||
)
|
||||
print("Model card created successfully.")
|
||||
if args.dry_run:
|
||||
print("\nTemplate Content:\n")
|
||||
print(model_card_content)
|
||||
else:
|
||||
print("Failed to create model card.")
|
||||
if model_card_content:
|
||||
api.upload_file(
|
||||
path_or_fileobj=model_card_content.encode('utf-8'),
|
||||
path_in_repo="README.md",
|
||||
repo_id=repo_id
|
||||
)
|
||||
print("Model card created successfully.")
|
||||
else:
|
||||
print("Failed to create model card.")
|
||||
|
||||
print(f"Repository created: {repo_url}")
|
||||
if not args.dry_run and repo_url:
|
||||
print(f"Repository created: {repo_url}")
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ set -e
|
||||
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
|
||||
TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
|
||||
OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
|
||||
QUANTIZED_MODEL=$CONVERTED_MODEL
|
||||
|
||||
# Final check if we have a model path
|
||||
@@ -14,6 +16,11 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$QUANTIZED_TYPE" ]; then
|
||||
echo "Error: QUANTIZED_TYPE is required" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
|
||||
# Process the quantized model filename
|
||||
@@ -26,9 +33,16 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
cmake --build ../../build --target llama-quantize -j8
|
||||
|
||||
../../build/bin/llama-quantize $CONVERTED_MODEL $QUANTIZED_MODEL $QUANTIZED_TYPE
|
||||
echo $TOKEN_EMBD_TYPE
|
||||
echo $OUTPUT_TYPE
|
||||
|
||||
CMD_ARGS=("../../build/bin/llama-quantize")
|
||||
[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
|
||||
[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
|
||||
CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
|
||||
|
||||
"${CMD_ARGS[@]}"
|
||||
|
||||
echo "Quantized model saved to: $QUANTIZED_MODEL"
|
||||
|
||||
@@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
|
||||
// stochastic verification
|
||||
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
|
||||
|
||||
auto & dist_tgt = *common_sampler_get_candidates(smpl);
|
||||
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
|
||||
|
||||
float p_tgt = 0.0f;
|
||||
float p_dft = 0.0f;
|
||||
@@ -493,7 +493,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
|
||||
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
|
||||
+1
-1
@@ -1,5 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
|
||||
project("ggml" C CXX)
|
||||
project("ggml" C CXX ASM)
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
@@ -307,6 +307,9 @@ extern "C" {
|
||||
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
||||
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
||||
|
||||
// Split graph without allocating it
|
||||
GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
||||
|
||||
// Allocate and compute graph on the backend scheduler
|
||||
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
|
||||
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
||||
|
||||
@@ -512,6 +512,7 @@ extern "C" {
|
||||
GGML_OP_IM2COL,
|
||||
GGML_OP_IM2COL_BACK,
|
||||
GGML_OP_CONV_2D,
|
||||
GGML_OP_CONV_3D,
|
||||
GGML_OP_CONV_2D_DW,
|
||||
GGML_OP_CONV_TRANSPOSE_2D,
|
||||
GGML_OP_POOL_1D,
|
||||
@@ -1940,6 +1941,23 @@ extern "C" {
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
|
||||
struct ggml_tensor * b, // input [W, H, D, C * N]
|
||||
int s0, // stride
|
||||
int s1,
|
||||
int s2,
|
||||
int p0, // padding
|
||||
int p1,
|
||||
int p2,
|
||||
int d0, // dilation
|
||||
int d1,
|
||||
int d2,
|
||||
int n_channels,
|
||||
int n_batch,
|
||||
int n_channels_out);
|
||||
|
||||
enum ggml_op_pool {
|
||||
GGML_OP_POOL_MAX,
|
||||
GGML_OP_POOL_AVG,
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
// backend buffer type
|
||||
|
||||
const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(buft);
|
||||
return buft->iface.get_name(buft);
|
||||
}
|
||||
|
||||
@@ -40,14 +41,17 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t
|
||||
return ggml_backend_buffer_init(buft, {}, NULL, 0);
|
||||
}
|
||||
|
||||
GGML_ASSERT(buft);
|
||||
return buft->iface.alloc_buffer(buft, size);
|
||||
}
|
||||
|
||||
size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(buft);
|
||||
return buft->iface.get_alignment(buft);
|
||||
}
|
||||
|
||||
size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(buft);
|
||||
// get_max_size is optional, defaults to SIZE_MAX
|
||||
if (buft->iface.get_max_size) {
|
||||
return buft->iface.get_max_size(buft);
|
||||
@@ -56,6 +60,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
}
|
||||
|
||||
size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(buft);
|
||||
// get_alloc_size is optional, defaults to ggml_nbytes
|
||||
if (buft->iface.get_alloc_size) {
|
||||
size_t size = buft->iface.get_alloc_size(buft, tensor);
|
||||
@@ -66,6 +71,7 @@ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const s
|
||||
}
|
||||
|
||||
bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(buft);
|
||||
if (buft->iface.is_host) {
|
||||
return buft->iface.is_host(buft);
|
||||
}
|
||||
@@ -73,6 +79,7 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
|
||||
}
|
||||
|
||||
ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(buft);
|
||||
return buft->device;
|
||||
}
|
||||
|
||||
@@ -110,10 +117,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
|
||||
}
|
||||
|
||||
size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
return buffer->size;
|
||||
}
|
||||
|
||||
void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
// get_base is optional if the buffer is zero-sized
|
||||
if (buffer->size == 0) {
|
||||
return NULL;
|
||||
@@ -127,6 +136,7 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
}
|
||||
|
||||
enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(buffer);
|
||||
// init_tensor is optional
|
||||
if (buffer->iface.init_tensor) {
|
||||
return buffer->iface.init_tensor(buffer, tensor);
|
||||
@@ -135,6 +145,7 @@ enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, s
|
||||
}
|
||||
|
||||
void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
GGML_ASSERT(buffer);
|
||||
// clear is optional if the buffer is zero-sized
|
||||
if (buffer->size == 0) {
|
||||
return;
|
||||
@@ -160,6 +171,7 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
|
||||
}
|
||||
|
||||
void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
|
||||
GGML_ASSERT(buffer);
|
||||
buffer->usage = usage;
|
||||
|
||||
// FIXME: add a generic callback to the buffer interface
|
||||
@@ -169,14 +181,17 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe
|
||||
}
|
||||
|
||||
enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
return buffer->usage;
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
return buffer->buft;
|
||||
}
|
||||
|
||||
void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
if (buffer->iface.reset) {
|
||||
buffer->iface.reset(buffer);
|
||||
}
|
||||
@@ -215,6 +230,7 @@ void ggml_backend_free(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
|
||||
GGML_ASSERT(backend);
|
||||
return ggml_backend_dev_buffer_type(backend->device);
|
||||
}
|
||||
|
||||
@@ -231,6 +247,8 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(backend);
|
||||
GGML_ASSERT(tensor);
|
||||
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
|
||||
|
||||
@@ -242,6 +260,8 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
|
||||
}
|
||||
|
||||
void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(backend);
|
||||
GGML_ASSERT(tensor);
|
||||
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
|
||||
|
||||
@@ -283,6 +303,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz
|
||||
}
|
||||
|
||||
void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
GGML_ASSERT(tensor);
|
||||
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
|
||||
|
||||
if (size == 0) {
|
||||
@@ -298,6 +319,7 @@ void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size
|
||||
}
|
||||
|
||||
void ggml_backend_synchronize(ggml_backend_t backend) {
|
||||
GGML_ASSERT(backend);
|
||||
if (backend->iface.synchronize == NULL) {
|
||||
return;
|
||||
}
|
||||
@@ -306,18 +328,21 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
GGML_ASSERT(backend);
|
||||
GGML_ASSERT(backend->iface.graph_plan_create != NULL);
|
||||
|
||||
return backend->iface.graph_plan_create(backend, cgraph);
|
||||
}
|
||||
|
||||
void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
|
||||
GGML_ASSERT(backend);
|
||||
GGML_ASSERT(backend->iface.graph_plan_free != NULL);
|
||||
|
||||
backend->iface.graph_plan_free(backend, plan);
|
||||
}
|
||||
|
||||
enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
|
||||
GGML_ASSERT(backend);
|
||||
GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
|
||||
|
||||
return backend->iface.graph_plan_compute(backend, plan);
|
||||
@@ -330,22 +355,27 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_
|
||||
}
|
||||
|
||||
enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
GGML_ASSERT(backend);
|
||||
return backend->iface.graph_compute(backend, cgraph);
|
||||
}
|
||||
|
||||
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||
GGML_ASSERT(backend);
|
||||
return ggml_backend_dev_supports_op(backend->device, op);
|
||||
}
|
||||
|
||||
bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(backend);
|
||||
return ggml_backend_dev_supports_buft(backend->device, buft);
|
||||
}
|
||||
|
||||
bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||
GGML_ASSERT(backend);
|
||||
return ggml_backend_dev_offload_op(backend->device, op);
|
||||
}
|
||||
|
||||
ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
|
||||
GGML_ASSERT(backend);
|
||||
return backend->device;
|
||||
}
|
||||
|
||||
@@ -381,6 +411,7 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(backend_dst);
|
||||
if (backend_dst->iface.cpy_tensor_async != NULL) {
|
||||
if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
|
||||
return;
|
||||
@@ -412,18 +443,21 @@ void ggml_backend_event_free(ggml_backend_event_t event) {
|
||||
}
|
||||
|
||||
void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) {
|
||||
GGML_ASSERT(backend);
|
||||
GGML_ASSERT(backend->iface.event_record != NULL);
|
||||
|
||||
backend->iface.event_record(backend, event);
|
||||
}
|
||||
|
||||
void ggml_backend_event_synchronize(ggml_backend_event_t event) {
|
||||
GGML_ASSERT(event);
|
||||
GGML_ASSERT(event->device->iface.event_synchronize);
|
||||
|
||||
event->device->iface.event_synchronize(event->device, event);
|
||||
}
|
||||
|
||||
void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
|
||||
GGML_ASSERT(backend);
|
||||
GGML_ASSERT(backend->iface.event_wait != NULL);
|
||||
|
||||
backend->iface.event_wait(backend, event);
|
||||
@@ -432,18 +466,22 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
|
||||
// Backend device
|
||||
|
||||
const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.get_name(device);
|
||||
}
|
||||
|
||||
const char * ggml_backend_dev_description(ggml_backend_dev_t device) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.get_description(device);
|
||||
}
|
||||
|
||||
void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
|
||||
GGML_ASSERT(device);
|
||||
device->iface.get_memory(device, free, total);
|
||||
}
|
||||
|
||||
enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.get_type(device);
|
||||
}
|
||||
|
||||
@@ -453,18 +491,22 @@ void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_d
|
||||
}
|
||||
|
||||
ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) {
|
||||
GGML_ASSERT(device);
|
||||
return device->reg;
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.init_backend(device, params);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.get_buffer_type(device);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
|
||||
GGML_ASSERT(device);
|
||||
if (device->iface.get_host_buffer_type == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
@@ -473,18 +515,22 @@ ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t
|
||||
}
|
||||
|
||||
ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size);
|
||||
}
|
||||
|
||||
bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.supports_op(device, op);
|
||||
}
|
||||
|
||||
bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) {
|
||||
GGML_ASSERT(device);
|
||||
return device->iface.supports_buft(device, buft);
|
||||
}
|
||||
|
||||
bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
|
||||
GGML_ASSERT(device);
|
||||
if (device->iface.offload_op != NULL) {
|
||||
return device->iface.offload_op(device, op);
|
||||
}
|
||||
@@ -495,18 +541,22 @@ bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_te
|
||||
// Backend (reg)
|
||||
|
||||
const char * ggml_backend_reg_name(ggml_backend_reg_t reg) {
|
||||
GGML_ASSERT(reg);
|
||||
return reg->iface.get_name(reg);
|
||||
}
|
||||
|
||||
size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) {
|
||||
GGML_ASSERT(reg);
|
||||
return reg->iface.get_device_count(reg);
|
||||
}
|
||||
|
||||
ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(reg);
|
||||
return reg->iface.get_device(reg, index);
|
||||
}
|
||||
|
||||
void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
GGML_ASSERT(reg);
|
||||
if (!reg->iface.get_proc_address) {
|
||||
return NULL;
|
||||
}
|
||||
@@ -521,6 +571,7 @@ struct ggml_backend_multi_buffer_context {
|
||||
};
|
||||
|
||||
static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
|
||||
for (size_t i = 0; i < ctx->n_buffers; i++) {
|
||||
ggml_backend_buffer_free(ctx->buffers[i]);
|
||||
@@ -531,6 +582,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
||||
}
|
||||
|
||||
static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
GGML_ASSERT(buffer);
|
||||
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
|
||||
for (size_t i = 0; i < ctx->n_buffers; i++) {
|
||||
ggml_backend_buffer_clear(ctx->buffers[i], value);
|
||||
@@ -566,10 +618,12 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer
|
||||
}
|
||||
|
||||
bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;
|
||||
}
|
||||
|
||||
void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
|
||||
GGML_ASSERT(buffer);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
|
||||
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
|
||||
for (size_t i = 0; i < ctx->n_buffers; i++) {
|
||||
@@ -597,7 +651,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
|
||||
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
|
||||
#define GGML_SCHED_MAX_SPLIT_INPUTS 30
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_COPIES
|
||||
@@ -848,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
|
||||
}
|
||||
|
||||
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
|
||||
static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
|
||||
void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
|
||||
// reset splits
|
||||
sched->n_splits = 0;
|
||||
sched->n_graph_inputs = 0;
|
||||
@@ -1349,6 +1403,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
|
||||
GGML_ASSERT(sched);
|
||||
struct ggml_backend_sched_split * splits = sched->splits;
|
||||
|
||||
ggml_tensor * prev_ids_tensor = nullptr;
|
||||
@@ -1617,6 +1672,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
|
||||
}
|
||||
|
||||
void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
|
||||
GGML_ASSERT(sched);
|
||||
// reset state for the next run
|
||||
if (!sched->is_reset) {
|
||||
ggml_hash_set_reset(&sched->hash_set);
|
||||
@@ -1628,8 +1684,11 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
|
||||
}
|
||||
|
||||
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
|
||||
GGML_ASSERT(sched);
|
||||
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
|
||||
|
||||
ggml_backend_sched_reset(sched);
|
||||
|
||||
ggml_backend_sched_synchronize(sched);
|
||||
|
||||
ggml_backend_sched_split_graph(sched, measure_graph);
|
||||
@@ -1644,6 +1703,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
|
||||
}
|
||||
|
||||
bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
|
||||
GGML_ASSERT(sched);
|
||||
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
|
||||
GGML_ASSERT(!sched->is_alloc);
|
||||
|
||||
@@ -1668,6 +1728,7 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
|
||||
}
|
||||
|
||||
enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
|
||||
GGML_ASSERT(sched);
|
||||
if (!sched->is_reset && !sched->is_alloc) {
|
||||
ggml_backend_sched_reset(sched);
|
||||
}
|
||||
@@ -1682,6 +1743,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
|
||||
}
|
||||
|
||||
void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
|
||||
GGML_ASSERT(sched);
|
||||
for (int i = 0; i < sched->n_backends; i++) {
|
||||
ggml_backend_synchronize(sched->backends[i]);
|
||||
}
|
||||
@@ -1694,28 +1756,34 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
|
||||
}
|
||||
|
||||
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
|
||||
GGML_ASSERT(sched);
|
||||
sched->callback_eval = callback;
|
||||
sched->callback_eval_user_data = user_data;
|
||||
}
|
||||
|
||||
int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
|
||||
GGML_ASSERT(sched);
|
||||
return sched->n_splits;
|
||||
}
|
||||
|
||||
int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
|
||||
GGML_ASSERT(sched);
|
||||
return sched->n_copies;
|
||||
}
|
||||
|
||||
int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
|
||||
GGML_ASSERT(sched);
|
||||
return sched->n_backends;
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
|
||||
GGML_ASSERT(sched);
|
||||
GGML_ASSERT(i >= 0 && i < sched->n_backends);
|
||||
return sched->backends[i];
|
||||
}
|
||||
|
||||
size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
|
||||
GGML_ASSERT(sched);
|
||||
int backend_index = ggml_backend_sched_backend_id(sched, backend);
|
||||
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
||||
|
||||
@@ -1723,6 +1791,7 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
|
||||
}
|
||||
|
||||
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
|
||||
GGML_ASSERT(sched);
|
||||
int backend_index = ggml_backend_sched_backend_id(sched, backend);
|
||||
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
||||
tensor_backend_id(node) = backend_index;
|
||||
@@ -1731,6 +1800,7 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
|
||||
GGML_ASSERT(sched);
|
||||
int backend_index = tensor_backend_id(node);
|
||||
if (backend_index == -1) {
|
||||
return NULL;
|
||||
@@ -1741,6 +1811,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
|
||||
// utils
|
||||
|
||||
enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor);
|
||||
GGML_ASSERT(tensor->buffer == NULL);
|
||||
GGML_ASSERT(tensor->view_src != NULL);
|
||||
GGML_ASSERT(tensor->view_src->buffer != NULL);
|
||||
@@ -1752,6 +1823,7 @@ enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
|
||||
}
|
||||
|
||||
enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
|
||||
GGML_ASSERT(tensor);
|
||||
GGML_ASSERT(tensor->buffer == NULL);
|
||||
GGML_ASSERT(tensor->data == NULL);
|
||||
GGML_ASSERT(tensor->view_src == NULL);
|
||||
@@ -1825,6 +1897,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_
|
||||
}
|
||||
|
||||
struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
|
||||
GGML_ASSERT(graph);
|
||||
struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
|
||||
struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
|
||||
bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0]));
|
||||
@@ -1969,6 +2042,7 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
|
||||
// CPU backend - buffer
|
||||
|
||||
static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
uintptr_t data = (uintptr_t)buffer->context;
|
||||
|
||||
// align the buffer
|
||||
@@ -1980,28 +2054,33 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
}
|
||||
|
||||
static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
GGML_ASSERT(buffer);
|
||||
ggml_aligned_free(buffer->context, buffer->size);
|
||||
}
|
||||
|
||||
static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
GGML_ASSERT(tensor);
|
||||
memset((char *)tensor->data + offset, value, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(tensor);
|
||||
memcpy((char *)tensor->data + offset, data, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
GGML_ASSERT(tensor);
|
||||
memcpy(data, (const char *)tensor->data + offset, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(src);
|
||||
if (ggml_backend_buffer_is_host(src->buffer)) {
|
||||
memcpy(dst->data, src->data, ggml_nbytes(src));
|
||||
return true;
|
||||
@@ -2012,6 +2091,7 @@ static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
|
||||
}
|
||||
|
||||
static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
GGML_ASSERT(buffer);
|
||||
memset(buffer->context, value, buffer->size);
|
||||
}
|
||||
|
||||
|
||||
+321
-354
@@ -70,6 +70,8 @@
|
||||
#include <aclnnop/aclnn_zero.h>
|
||||
#include <aclnnop/aclnn_index_copy.h>
|
||||
#include <aclnnop/aclnn_index_select.h>
|
||||
#include <aclnnop/aclnn_clamp.h>
|
||||
#include <aclnnop/aclnn_threshold.h>
|
||||
#include <float.h>
|
||||
|
||||
#include <cmath>
|
||||
@@ -964,8 +966,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
}
|
||||
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.f32_one_cache,
|
||||
ctx.f32_one_cache_element,
|
||||
&ctx.rms_norm_one_tensor_cache.cache,
|
||||
ctx.rms_norm_one_tensor_cache.size,
|
||||
src->ne,
|
||||
acl_gamma_nb,
|
||||
1, // dims
|
||||
@@ -980,8 +982,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
}
|
||||
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.f32_zero_cache,
|
||||
ctx.f32_zero_cache_element,
|
||||
&ctx.rms_norm_zero_tensor_cache.cache,
|
||||
ctx.rms_norm_zero_tensor_cache.size,
|
||||
src->ne,
|
||||
acl_rstd_nb,
|
||||
GGML_MAX_DIMS,
|
||||
@@ -1257,12 +1259,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
|
||||
|
||||
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
|
||||
if(acl_dst == nullptr) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
|
||||
} else {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
|
||||
if(acl_dst == nullptr) {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
|
||||
} else {
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
|
||||
@@ -1419,17 +1429,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
|
||||
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
|
||||
float m, int64_t size, float start, float stop, float step){
|
||||
int64_t ne[] = {size};
|
||||
size_t nb[] = {sizeof(float)};
|
||||
size_t nb[] = {sizeof(uint16_t)};
|
||||
|
||||
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
|
||||
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
|
||||
void* arange_buffer = arange_allocator.get();
|
||||
|
||||
aclTensor* arange_tensor = ggml_cann_create_tensor(
|
||||
arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
|
||||
arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
|
||||
aclnn_arange(ctx, arange_tensor, start, stop, step, size);
|
||||
|
||||
aclTensor* slope_tensor = ggml_cann_create_tensor(
|
||||
slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
|
||||
slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
|
||||
|
||||
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
|
||||
|
||||
@@ -2221,9 +2231,41 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
|
||||
ggml_cann_release_resources(ctx, acl_index, acl_value);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initializes and caches sine/cosine positional encoding values
|
||||
* (used in RoPE, Rotary Position Embedding) for attention layers.
|
||||
*
|
||||
* This function computes and caches the sin/cos values of
|
||||
* θ = position * theta_scale for RoPE encoding. The cache is shared
|
||||
* across attention layers, and only the first attention layer will
|
||||
* trigger initialization. The cache includes repeated sin/cos values
|
||||
* with different repeat methods depending on the @param is_neox flag.
|
||||
*
|
||||
* Steps performed by this function:
|
||||
* 1. Identify whether the target tensor belongs to Q/K in attention
|
||||
* and restrict computation to the first layer only.
|
||||
* 2. Initialize the theta scale array (arange → power → freq scaling).
|
||||
* 3. Allocate sin/cos caches if the max prompt length increases.
|
||||
* 4. Compute θ = position * theta_scale.
|
||||
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
|
||||
* 6. Expand sin/cos values by repeat or repeat_interleave depending
|
||||
* on whether @param is_neox is enabled.
|
||||
*
|
||||
* @param ctx The CANN backend context, holding memory pool,
|
||||
* stream, and persistent buffers for rope init/cache.
|
||||
* @param dst The destination ggml_tensor whose computation
|
||||
* depends on the RoPE values (usually Qcur/Kcur).
|
||||
* @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
|
||||
* @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
|
||||
* @param theta_scale Scalar exponent base for computing theta scale values.
|
||||
* @param freq_scale Frequency scaling factor, applied to theta scale.
|
||||
* @param attn_factor Attention scaling factor, applied to sin/cos.
|
||||
* @param is_neox Whether to use Neox-style repeat strategy
|
||||
* (dim expansion vs repeat_interleave).
|
||||
*/
|
||||
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
aclTensor* acl_cos_repeat_tensor,
|
||||
aclTensor* acl_sin_repeat_tensor,
|
||||
void* sin_tensor_buffer, void* cos_tensor_buffer,
|
||||
float* corr_dims, float ext_factor,
|
||||
float theta_scale, float freq_scale,
|
||||
float attn_factor, bool is_neox) {
|
||||
// int sin/cos cache, cache has different repeat method depond on
|
||||
@@ -2233,9 +2275,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
ggml_tensor* src1 = dst->src[1]; // position
|
||||
ggml_tensor* src2 = dst->src[2]; // freq_factors
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
int64_t theta_scale_length = ne00 / 2;
|
||||
int64_t theta_scale_length = src0->ne[0] / 2;
|
||||
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
|
||||
size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
|
||||
theta_scale_length * sizeof(float_t)};
|
||||
@@ -2253,111 +2293,148 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
|
||||
}
|
||||
|
||||
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
|
||||
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
|
||||
// theta_scale arange, [0,1,...,ne00/2 - 1]
|
||||
aclTensor* acl_theta_scale_tensor = nullptr;
|
||||
// cache theta scale
|
||||
if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
|
||||
// theta_scale and freq_scale should not change during the current token inference process,
|
||||
// so we can directly use == here instead of comparing the absolute difference.
|
||||
ctx.rope_cache.theta_scale != theta_scale ||
|
||||
ctx.rope_cache.freq_scale != freq_scale) {
|
||||
|
||||
// used for accuracy testing
|
||||
bool is_attention = is_q || is_k;
|
||||
ctx.rope_cache.theta_scale_length = theta_scale_length;
|
||||
ctx.rope_cache.theta_scale = theta_scale;
|
||||
ctx.rope_cache.freq_scale = freq_scale;
|
||||
|
||||
if(ctx.init_ptr == nullptr || !is_attention) {
|
||||
// theta_scale arange, [0,1,...,ne00/2 - 1]
|
||||
if(ctx.init_ptr != nullptr){
|
||||
ACL_CHECK(aclrtFree(ctx.init_ptr));
|
||||
if (ctx.rope_cache.theta_scale_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
|
||||
}
|
||||
ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
|
||||
float start = 0;
|
||||
float step = 1;
|
||||
float stop = ne00 / 2;
|
||||
float n_elements = ne00 / 2;
|
||||
float stop = theta_scale_length;
|
||||
float n_elements = theta_scale_length;
|
||||
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
|
||||
|
||||
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
|
||||
aclTensor* acl_yarn_ramp_tensor = nullptr;
|
||||
if (ext_factor != 0) {
|
||||
// -rope_yarn_ramp
|
||||
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
||||
// return MIN(1, MAX(0, y)) - 1;
|
||||
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
|
||||
void* yarn_ramp_buffer = yarn_ramp_allocator.get();
|
||||
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
float zero_value = 0, one_value = 1;
|
||||
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
|
||||
aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
|
||||
aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
|
||||
aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
|
||||
aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
|
||||
aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);
|
||||
|
||||
// theta_interp = freq_scale * theta_extrap;
|
||||
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
|
||||
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
|
||||
//
|
||||
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
|
||||
// cache freq_scale + (freq_scale - 1) * ramp_mix
|
||||
float freq_scale_1 = freq_scale - 1;
|
||||
aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
|
||||
aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);
|
||||
|
||||
ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
|
||||
}
|
||||
|
||||
// power
|
||||
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
|
||||
acl_theta_scale_tensor);
|
||||
|
||||
// freq_scale
|
||||
if (freq_scale != 1) {
|
||||
if (ext_factor != 0) {
|
||||
aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
|
||||
} else if (freq_scale != 1) {
|
||||
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
|
||||
}
|
||||
|
||||
// freq_factors
|
||||
if (src2) {
|
||||
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
|
||||
src2->data, ggml_cann_type_mapping(src2->type),
|
||||
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
|
||||
}
|
||||
// release
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
|
||||
}
|
||||
|
||||
if(ctx.sin_ptr == nullptr) {
|
||||
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
|
||||
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
}
|
||||
if(position_length > ctx.max_prompt_length) {
|
||||
ctx.max_prompt_length = position_length;
|
||||
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
|
||||
ACL_CHECK(aclrtFree(ctx.sin_ptr));
|
||||
ACL_CHECK(aclrtFree(ctx.cos_ptr));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
|
||||
}
|
||||
|
||||
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
|
||||
|
||||
if(is_fisrt_layer || !is_attention) {
|
||||
|
||||
aclTensor* acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
|
||||
} else {
|
||||
// use cache
|
||||
acl_theta_scale_tensor =
|
||||
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
|
||||
// position
|
||||
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
|
||||
src1->data, ggml_cann_type_mapping(src1->type),
|
||||
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
|
||||
|
||||
// power * position
|
||||
int64_t theta_length = theta_scale_length * position_length;
|
||||
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* theta_buffer = theta_allocator.get();
|
||||
|
||||
aclTensor* acl_theta_tensor =
|
||||
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
theta_ne, theta_nb, GGML_MAX_DIMS);
|
||||
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
|
||||
acl_theta_tensor);
|
||||
|
||||
// sin/cos
|
||||
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
|
||||
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
|
||||
|
||||
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
|
||||
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
|
||||
|
||||
// release
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
|
||||
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
|
||||
}
|
||||
|
||||
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
|
||||
// freq_factors
|
||||
if (src2) {
|
||||
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
|
||||
void* freq_fac_res_ptr = freq_fac_res_allocator.get();
|
||||
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
|
||||
src2->data, ggml_cann_type_mapping(src2->type),
|
||||
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
|
||||
freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
|
||||
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
|
||||
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
|
||||
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
|
||||
}
|
||||
|
||||
// position
|
||||
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
|
||||
src1->data, ggml_cann_type_mapping(src1->type),
|
||||
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
|
||||
|
||||
// power * position
|
||||
int64_t theta_length = theta_scale_length * position_length;
|
||||
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* theta_buffer = theta_allocator.get();
|
||||
|
||||
aclTensor* acl_theta_tensor =
|
||||
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
theta_ne, theta_nb, GGML_MAX_DIMS);
|
||||
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
|
||||
acl_theta_tensor);
|
||||
|
||||
// sin/cos
|
||||
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* sin_buffer = sin_allocator.get();
|
||||
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
|
||||
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
|
||||
|
||||
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
|
||||
theta_length * sizeof(float_t));
|
||||
void* cos_buffer = cos_allocator.get();
|
||||
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
|
||||
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
|
||||
GGML_MAX_DIMS, ACL_FORMAT_ND);
|
||||
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
|
||||
|
||||
if (ext_factor != 0) {
|
||||
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
|
||||
// attn_factor
|
||||
if (attn_factor != 1) {
|
||||
@@ -2365,6 +2442,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
|
||||
}
|
||||
|
||||
int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
sin_reshape_nb[0] = sizeof(float_t);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_repeat_tensor =
|
||||
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_repeat_tensor =
|
||||
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
|
||||
// repeat
|
||||
if (is_neox) {
|
||||
int64_t repeatsArray[] = {1, 1, 1, 2};
|
||||
@@ -2380,8 +2470,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
|
||||
num_repeats, output_size);
|
||||
}
|
||||
|
||||
// release
|
||||
ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
|
||||
acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
|
||||
acl_cos_repeat_tensor);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
@@ -2403,6 +2494,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
// TODO: use ascendc
|
||||
// Only test with LLAMA model.
|
||||
ggml_tensor* src0 = dst->src[0]; // input
|
||||
ggml_tensor* src1 = dst->src[1];
|
||||
|
||||
// param
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
@@ -2424,8 +2516,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
// TODO: n_dims <= ne0
|
||||
GGML_ASSERT(n_dims == ne0);
|
||||
GGML_ASSERT(n_dims % 2 == 0);
|
||||
// TODO: ext_factor != 0
|
||||
GGML_ASSERT(ext_factor == 0);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
@@ -2435,13 +2525,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
// init cos/sin cache
|
||||
ggml_cann_pool_alloc sin_allocator(
|
||||
ctx.pool(), ne00 * ne02 * sizeof(float_t));
|
||||
ggml_cann_pool_alloc cos_allocator(
|
||||
ctx.pool(), ne00 * ne02 * sizeof(float_t));
|
||||
void* sin_buffer = sin_allocator.get();
|
||||
void* cos_buffer = cos_allocator.get();
|
||||
// sin/cos tensor length.
|
||||
int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
|
||||
ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
|
||||
ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
|
||||
void *sin_tensor_buffer = sin_tensor_allocator.get();
|
||||
void *cos_tensor_buffer = cos_tensor_allocator.get();
|
||||
|
||||
// init ctx.rope_cos/rope_sin cache
|
||||
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
|
||||
theta_scale, freq_scale, attn_factor, is_neox);
|
||||
|
||||
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
|
||||
size_t sin_reshape_nb[GGML_MAX_DIMS];
|
||||
@@ -2450,13 +2543,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_sin_reshape_tensor =
|
||||
ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclTensor* acl_cos_reshape_tensor =
|
||||
ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
|
||||
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
|
||||
aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
|
||||
theta_scale, freq_scale, attn_factor, is_neox);
|
||||
|
||||
aclTensor* acl_src = ggml_cann_create_tensor(src0);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
@@ -2825,174 +2916,49 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
*/
|
||||
static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
//dst [M, K, N, 1]
|
||||
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
|
||||
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
|
||||
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] -> [D, M, K, 1]
|
||||
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1]
|
||||
ggml_tensor * ids = dst->src[2]; //ids [K, N]
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_ASSERT(src1->ne[3] == 1);
|
||||
GGML_ASSERT(dst->ne[3] == 1);
|
||||
|
||||
// copy index from npu to cpu
|
||||
int64_t n_as = ne02; // A
|
||||
int64_t n_ids = ids->ne[0]; // K
|
||||
int64_t batch = src1->ne[2];
|
||||
GGML_ASSERT(batch == ids->ne[1]);
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
|
||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
|
||||
ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0));
|
||||
void* export_ptr = export_allocator.get();
|
||||
for (int64_t i = 0; i < batch; i++) {
|
||||
aclTensor *select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);
|
||||
aclTensor *export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3);
|
||||
|
||||
char * src0_original = (char *) src0->data;
|
||||
char * src1_original = (char *) src1->data;
|
||||
char * dst_original = (char *) dst->data;
|
||||
size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
|
||||
|
||||
// src0 is F16, src1 is F32, dst is F32
|
||||
ggml_cann_pool_alloc src0_cast_allocator;
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
|
||||
void* src0_cast_buf = src0_cast_allocator.get();
|
||||
|
||||
size_t cast_nb[GGML_MAX_DIMS];
|
||||
cast_nb[0] = sizeof(float_t);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
|
||||
int64_t select_export_ne[] = {src0->ne[0], src0->ne[1], ids->ne[0]};
|
||||
size_t select_export_nb[3];
|
||||
select_export_nb[0] = src0->nb[0];
|
||||
for (int k = 1;k < 3; k++) {
|
||||
select_export_nb[k] = select_export_nb[k-1] * select_export_ne[k-1];
|
||||
}
|
||||
|
||||
aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
|
||||
aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
|
||||
ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
|
||||
ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
|
||||
aclTensor *select_export = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_export_ne, select_export_nb, 3);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight, 0, select_index, select_export);
|
||||
|
||||
src0_original = (char *) src0_cast_buf;
|
||||
memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
|
||||
int64_t select_transpose_ne[] = {select_export_ne[1], select_export_ne[0], select_export_ne[2]};
|
||||
size_t select_transpose_nb[] = {select_export_nb[1], select_export_nb[0], select_export_nb[2]};
|
||||
aclTensor *select_export_transpose = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_transpose_ne, select_transpose_nb, 3);
|
||||
|
||||
int64_t active_tensor_ne[] = {src1->ne[0], 1, src1->ne[1]};
|
||||
size_t active_tensor_nb[] = {src1->nb[0], src1->nb[1], src1->nb[1]};
|
||||
aclTensor *active_tensor = ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]);
|
||||
|
||||
int64_t dst_ne[] = {dst->ne[0], 1, dst->ne[1]};
|
||||
size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[1]};
|
||||
aclTensor *acl_dst = ggml_cann_create_tensor(dst, dst_ne,dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor, select_export_transpose, acl_dst, 2);
|
||||
|
||||
ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, select_export_transpose);
|
||||
}
|
||||
|
||||
#ifdef ASCEND_310P
|
||||
ggml_tensor src0_row = *src0;
|
||||
ggml_tensor src1_row = *src1;
|
||||
ggml_tensor dst_row = *dst;
|
||||
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
src0_row.type = GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
// src0_row [D, M, 1, 1] weight without permute
|
||||
src0_row.ne[2] = 1;
|
||||
src0_row.ne[3] = 1;
|
||||
src0_row.nb[0] = ori_src0_nb[0];
|
||||
src0_row.nb[1] = ori_src0_nb[1];
|
||||
src0_row.nb[2] = ori_src0_nb[1];
|
||||
src0_row.nb[3] = ori_src0_nb[1];
|
||||
|
||||
// src1_row [D, 1, 1, 1] -> input
|
||||
src1_row.ne[1] = 1;
|
||||
src1_row.ne[2] = 1;
|
||||
src1_row.ne[3] = 1;
|
||||
src1_row.nb[2] = nb11;
|
||||
src1_row.nb[3] = nb11;
|
||||
|
||||
// dst_row [M, 1, 1, 1] -> out
|
||||
dst_row.ne[1] = 1;
|
||||
dst_row.ne[2] = 1;
|
||||
dst_row.ne[3] = 1;
|
||||
dst_row.nb[2] = nb1;
|
||||
dst_row.nb[3] = nb1;
|
||||
|
||||
//create weight for one row
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
// expert index
|
||||
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
// If B = 1 (broadcast), always use 0; otherwise, use id.
|
||||
int64_t i11 = (ne11 == 1 ? 0 : id);
|
||||
int64_t i12 = iid1;
|
||||
|
||||
int64_t i1 = id;
|
||||
int64_t i2 = i12;
|
||||
|
||||
void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
|
||||
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
|
||||
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
|
||||
|
||||
src0_row.data = src0_tmp_ptr;
|
||||
src1_row.data = src1_tmp_ptr;
|
||||
dst_row.data = dst_tmp_ptr;
|
||||
dst_row.src[0] = &src0_row;
|
||||
dst_row.src[1] = &src1_row;
|
||||
|
||||
ggml_cann_mul_mat(ctx, &dst_row);
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
|
||||
std::vector<aclTensor*> src0_tensor_vec;
|
||||
std::vector<aclTensor*> src1_tensor_vec;
|
||||
std::vector<aclTensor*> dst_tensor_vec;
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
// src0_row [M, D] -> weight && permute
|
||||
int64_t src0_ne[2] = {ne01, ne00};
|
||||
size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
|
||||
// src1_row [D, 1] -> input
|
||||
int64_t src1_ne[2] = {ne10, 1};
|
||||
size_t src1_nb[2] = {nb10, nb11};
|
||||
// dst_row [M, 1] -> out
|
||||
int64_t dst_ne[2] = {ne0, 1};
|
||||
size_t dst_nb[2] = {nb0, nb1};
|
||||
|
||||
// expert index
|
||||
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
// If B = 1 (broadcast), always use 0; otherwise, use id.
|
||||
int64_t i11 = (ne11 == 1 ? 0 : id);
|
||||
int64_t i12 = iid1;
|
||||
|
||||
int64_t i1 = id;
|
||||
int64_t i2 = i12;
|
||||
|
||||
void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
|
||||
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
|
||||
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
|
||||
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
src0_ne, src0_nb, 2);
|
||||
aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
src1_ne, src1_nb, 2);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
|
||||
ACL_FLOAT, sizeof(float),
|
||||
dst_ne, dst_nb, 2);
|
||||
|
||||
src0_tensor_vec.push_back(acl_src0);
|
||||
src1_tensor_vec.push_back(acl_src1);
|
||||
dst_tensor_vec.push_back(acl_dst);
|
||||
}
|
||||
}
|
||||
|
||||
size_t GROUP_SIZE = 128;
|
||||
// GroupedMatmulV3 required tensor_list.size < 128
|
||||
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
|
||||
// split and call GroupedMatmulV3
|
||||
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
|
||||
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
|
||||
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
|
||||
std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
|
||||
|
||||
aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
|
||||
aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
|
||||
aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV3, src1_tensor_list, src0_tensor_list,
|
||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
|
||||
|
||||
ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -3141,11 +3107,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
|
||||
ggml_tensor* src0 = dst->src[0]; // q, fp32
|
||||
ggml_tensor* src1 = dst->src[1]; // k, fp16
|
||||
ggml_tensor* src2 = dst->src[2]; // v, fp16
|
||||
ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
|
||||
ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
|
||||
ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
|
||||
ggml_tensor* src3 = dst->src[3]; // mask, fp16
|
||||
|
||||
// B, N, S, D (uncont) -> B, S, N, D (cont)
|
||||
int64_t src0_bsnd_ne[GGML_MAX_DIMS];
|
||||
memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
|
||||
size_t src0_bsnd_nb[GGML_MAX_DIMS];
|
||||
memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
|
||||
int64_t src1_bsnd_ne[GGML_MAX_DIMS];
|
||||
memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
|
||||
size_t src1_bsnd_nb[GGML_MAX_DIMS];
|
||||
memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
|
||||
int64_t src2_bsnd_ne[GGML_MAX_DIMS];
|
||||
memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
|
||||
size_t src2_bsnd_nb[GGML_MAX_DIMS];
|
||||
memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
|
||||
|
||||
auto transpose12 = [](int64_t* ne, size_t* nb) {
|
||||
int64_t ne_tmp = ne[1];
|
||||
size_t nb_tmp = nb[1];
|
||||
ne[1] = ne[2];
|
||||
nb[1] = nb[2];
|
||||
ne[2] = ne_tmp;
|
||||
nb[2] = nb_tmp;
|
||||
};
|
||||
|
||||
transpose12(src0_bsnd_ne, src0_bsnd_nb);
|
||||
transpose12(src1_bsnd_ne, src1_bsnd_nb);
|
||||
transpose12(src2_bsnd_ne, src2_bsnd_nb);
|
||||
|
||||
float maxBias = 0.0f;
|
||||
float scaleValue = 1.0f;
|
||||
float logitSoftcap = 0.0f;
|
||||
@@ -3167,11 +3160,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
void* src0_f16_buffer = nullptr;
|
||||
|
||||
if(ggml_cann_type_mapping(src0->type) != faDataType){
|
||||
aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
|
||||
aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
|
||||
src0_bsnd_nb, GGML_MAX_DIMS);
|
||||
src0_f16_buffer = src0_f16_allocator.alloc(
|
||||
ggml_nelements(src0) * faElemSize);
|
||||
|
||||
int64_t* src0_f16_ne = src0->ne;
|
||||
int64_t* src0_f16_ne = src0_bsnd_ne;
|
||||
size_t src0_f16_nb[GGML_MAX_DIMS];
|
||||
src0_f16_nb[0] = sizeof(uint16_t);
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
@@ -3185,20 +3179,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
|
||||
ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
|
||||
}else{
|
||||
acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
|
||||
acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
|
||||
src0_bsnd_nb, GGML_MAX_DIMS);
|
||||
}
|
||||
|
||||
// Step 2: create the acl tensors for src1 (Key), src2 (Value),
|
||||
// and the direct output from FusedInferAttention
|
||||
|
||||
acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
|
||||
acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
|
||||
acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
|
||||
src1_bsnd_nb, GGML_MAX_DIMS);
|
||||
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
|
||||
src2_bsnd_nb, GGML_MAX_DIMS);
|
||||
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0->ne;
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
@@ -3212,88 +3209,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
|
||||
// Step 3: create the PSEShift tensor if needed
|
||||
// this tensor is considered as mask (f16) in the llama.cpp
|
||||
|
||||
aclTensor* bcast_pse_tensor = nullptr;
|
||||
int64_t bcast_pse_ne[GGML_MAX_DIMS];
|
||||
size_t bcast_pse_nb[GGML_MAX_DIMS];
|
||||
ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
|
||||
void* bcast_pse_buffer = nullptr;
|
||||
|
||||
if(src3 != nullptr){
|
||||
bcast_pse_buffer = bcast_pse_allocator.alloc(
|
||||
ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
|
||||
// Construct the truncated pse tensor (common for prefill/decode)
|
||||
int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
|
||||
src3->ne[0], // D
|
||||
src0->ne[1], // S (number of Q tokens)
|
||||
src3->ne[2], // mask N
|
||||
src3->ne[3] // B
|
||||
};
|
||||
size_t* trunc_pse_nb = src3->nb;
|
||||
|
||||
if(src0->ne[1] > 1){
|
||||
// Case 1: broadcast pse for prefill stage with multiple head
|
||||
aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
|
||||
bcast_pse_ne[0] = src3->ne[0];
|
||||
bcast_pse_ne[1] = src3->ne[1];
|
||||
bcast_pse_ne[2] = src0->ne[2];
|
||||
bcast_pse_ne[3] = src3->ne[3];
|
||||
aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
|
||||
src3->data, ACL_FLOAT16, sizeof(uint16_t),
|
||||
trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
|
||||
);
|
||||
|
||||
int64_t bcast_pse_ne[GGML_MAX_DIMS];
|
||||
size_t bcast_pse_nb[GGML_MAX_DIMS];
|
||||
bcast_pse_ne[0] = src3->ne[0]; // D
|
||||
bcast_pse_ne[1] = src0->ne[1]; // S
|
||||
bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
|
||||
bcast_pse_ne[3] = src3->ne[3]; // B
|
||||
if (maxBias == 0.0f) {
|
||||
// When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2)
|
||||
// Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
|
||||
bcast_pse_nb[0] = sizeof(uint16_t);
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
|
||||
}
|
||||
bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
|
||||
bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
|
||||
bcast_pse_nb[3] = src3->nb[3];
|
||||
|
||||
bcast_pse_tensor = ggml_cann_create_tensor(
|
||||
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
|
||||
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
|
||||
|
||||
int64_t repeats[] = {1, src0->ne[2], 1, 1};
|
||||
aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
|
||||
}else{
|
||||
// Case 2: trunc the first row and broadcast pse for decode stage with multiple head
|
||||
int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
|
||||
size_t* trunc_pse_nb = src3->nb;
|
||||
|
||||
aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
|
||||
src3->data, ACL_FLOAT16, sizeof(uint16_t),
|
||||
trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
|
||||
|
||||
bcast_pse_ne[0] = src3->ne[0];
|
||||
bcast_pse_ne[1] = src0->ne[1];
|
||||
bcast_pse_ne[2] = src0->ne[2];
|
||||
bcast_pse_ne[3] = src3->ne[3];
|
||||
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
|
||||
);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
|
||||
} else {
|
||||
bcast_pse_nb[0] = sizeof(uint16_t);
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
|
||||
}
|
||||
|
||||
void* bcast_pse_buffer = bcast_pse_allocator.alloc(
|
||||
ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
|
||||
);
|
||||
|
||||
bcast_pse_tensor = ggml_cann_create_tensor(
|
||||
bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
|
||||
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
|
||||
bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
|
||||
);
|
||||
|
||||
int64_t repeats[] = {1, src0->ne[2], 1, 1};
|
||||
aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
|
||||
}
|
||||
|
||||
// Compute the slope if needed. Derived from ggml_cann_softmax().
|
||||
if(maxBias != 0.0f){
|
||||
// alibi
|
||||
// Compute the slope if needed. Derived from ggml_cann_softmax().
|
||||
const int64_t n_heads = src0->ne[2];
|
||||
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
|
||||
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
|
||||
void* slope_buffer = slope_allocator.get();
|
||||
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
|
||||
|
||||
int64_t slope_ne[] = {1, 1, n_heads, 1};
|
||||
size_t slope_nb[GGML_MAX_DIMS];
|
||||
slope_nb[0] = sizeof(float);
|
||||
slope_nb[0] = sizeof(uint16_t);
|
||||
for(int i = 1;i<GGML_MAX_DIMS;i++) {
|
||||
slope_nb[i] = slope_nb[i-1] * slope_ne[0];
|
||||
}
|
||||
|
||||
aclTensor* slope_tensor = ggml_cann_create_tensor(
|
||||
slope_buffer, ACL_FLOAT, sizeof(float),
|
||||
slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
|
||||
slope_ne, slope_nb, GGML_MAX_DIMS);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
|
||||
|
||||
ggml_cann_release_resources(ctx, slope_tensor);
|
||||
ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3310,7 +3300,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
// double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
|
||||
int64_t preTokens = 65535;
|
||||
int64_t nextTokens = 65535;
|
||||
char layout[5] = {'B', 'N', 'S', 'D', 0};
|
||||
char layout[5] = {'B', 'S', 'N', 'D', 0};
|
||||
int64_t sparseMode = 0;
|
||||
int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
|
||||
int64_t blockSize = 0;
|
||||
@@ -3347,32 +3337,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
);
|
||||
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
|
||||
int64_t new_dim[] = {0, 2, 1, 3};
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
|
||||
if(ggml_cann_type_mapping(dst->type) != faDataType){
|
||||
ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
|
||||
perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
|
||||
void* perm_out_f16_buffer = perm_out_f16_allocator.get();
|
||||
|
||||
int64_t* perm_out_f16_ne = dst->ne;
|
||||
size_t perm_out_f16_nb[GGML_MAX_DIMS];
|
||||
perm_out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
|
||||
perm_out_f16_buffer, faDataType, faElemSize,
|
||||
perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
|
||||
aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx,
|
||||
acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
|
||||
}else{
|
||||
// only need to permute
|
||||
aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
|
||||
}
|
||||
// TODO: when dst is fp16, don't need cast
|
||||
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_src1_f16_tensor,
|
||||
acl_src2_f16_tensor,
|
||||
|
||||
+37
-26
@@ -360,6 +360,30 @@ struct ggml_cann_graph {
|
||||
};
|
||||
#endif // USE_ACL_GRAPH
|
||||
|
||||
struct ggml_cann_rope_cache {
|
||||
~ggml_cann_rope_cache() {
|
||||
if(theta_scale_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(theta_scale_cache));
|
||||
}
|
||||
}
|
||||
|
||||
void* theta_scale_cache = nullptr;
|
||||
int64_t theta_scale_length = 0;
|
||||
float theta_scale = 0.0f;
|
||||
float freq_scale = 0.0f;
|
||||
};
|
||||
|
||||
struct ggml_cann_tensor_cache {
|
||||
~ggml_cann_tensor_cache() {
|
||||
if(cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(cache));
|
||||
}
|
||||
}
|
||||
|
||||
void* cache = nullptr;
|
||||
int64_t size = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Context for managing CANN backend operations.
|
||||
*/
|
||||
@@ -368,21 +392,18 @@ struct ggml_backend_cann_context {
|
||||
std::string name; /**< Name of the device. */
|
||||
std::string description; /**< Description of the device. */
|
||||
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
|
||||
void* init_ptr = nullptr;
|
||||
void* sin_ptr = nullptr;
|
||||
void* cos_ptr = nullptr;
|
||||
int64_t max_prompt_length = 65536;
|
||||
#ifdef USE_ACL_GRAPH
|
||||
/// Cached CANN ACL graph used for executing the current ggml computation graph.
|
||||
std::unique_ptr<ggml_cann_graph> cann_graph;
|
||||
bool acl_graph_mode = true;
|
||||
#endif
|
||||
cann_task_queue task_queue;
|
||||
bool async_mode;
|
||||
bool support_set_rows;
|
||||
void* f32_zero_cache = nullptr;
|
||||
void* f32_one_cache = nullptr;
|
||||
int64_t f32_zero_cache_element = 0;
|
||||
int64_t f32_one_cache_element = 0;
|
||||
// Rope Cache
|
||||
ggml_cann_rope_cache rope_cache;
|
||||
// Constant Pool
|
||||
ggml_cann_tensor_cache rms_norm_one_tensor_cache;
|
||||
ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
|
||||
|
||||
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
|
||||
|
||||
@@ -398,14 +419,13 @@ struct ggml_backend_cann_context {
|
||||
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
|
||||
device, async_mode ? "ON" : "OFF");
|
||||
|
||||
support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
|
||||
GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
|
||||
|
||||
if (!support_set_rows) {
|
||||
GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
|
||||
"Falling back to eager mode.\n", __func__);
|
||||
}
|
||||
#ifdef USE_ACL_GRAPH
|
||||
acl_graph_mode = !(parse_bool(get_env("GGML_CANN_DISABLE_ACL_GRAPH").value_or("")));
|
||||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n",
|
||||
__func__, device,
|
||||
acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -422,15 +442,6 @@ struct ggml_backend_cann_context {
|
||||
ACL_CHECK(aclrtDestroyStream(streams[i]));
|
||||
}
|
||||
}
|
||||
if(init_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(init_ptr));
|
||||
}
|
||||
if(sin_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(sin_ptr));
|
||||
}
|
||||
if(cos_ptr != nullptr) {
|
||||
ACL_CHECK(aclrtFree(cos_ptr));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1155,7 +1155,7 @@ namespace {
|
||||
* @note The workspace buffer used in this function is managed globally and reused
|
||||
* across calls. This reduces overhead from repeated memory allocation and deallocation.
|
||||
*/
|
||||
static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
|
||||
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
|
||||
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
|
||||
tensor->nb, 2, ACL_FORMAT_ND, offset);
|
||||
uint64_t workspaceSize = 0;
|
||||
@@ -1203,7 +1203,7 @@ static void ggml_backend_cann_buffer_set_tensor(
|
||||
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
|
||||
GGML_ASSERT(tensor->ne[2] == 1);
|
||||
GGML_ASSERT(tensor->ne[3] == 1);
|
||||
weight_format_to_nz(tensor, data, offset);
|
||||
weight_format_to_nz(tensor, offset);
|
||||
}
|
||||
} else {
|
||||
void *transform_buffer = malloc(size);
|
||||
@@ -2247,12 +2247,12 @@ static enum ggml_status ggml_backend_cann_graph_compute(
|
||||
(ggml_backend_cann_context*)backend->context;
|
||||
ggml_cann_set_device(cann_ctx->device);
|
||||
release_nz_workspace();
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
bool cann_graph_update_required = false;
|
||||
|
||||
// check environment LLAMA_SET_ROWS
|
||||
if (!cann_ctx->support_set_rows) {
|
||||
if (!cann_ctx->acl_graph_mode) {
|
||||
use_cann_graph = false;
|
||||
}
|
||||
|
||||
@@ -2336,7 +2336,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
#ifdef ASCEND_310P
|
||||
// Q4 && Q8 per group is not suppor on 310p device
|
||||
// Q4 && Q8 per group is not support on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// only support contiguous for quantized types.
|
||||
@@ -2354,7 +2354,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
#ifdef ASCEND_310P
|
||||
// Q4 && Q8 per group is not suppor on 310p device
|
||||
// Q4 && Q8 per group is not support on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// only support contiguous for quantized types.
|
||||
@@ -2405,16 +2405,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
}
|
||||
case GGML_OP_ROPE: {
|
||||
// TODO: with ops-test v == 1
|
||||
float ext_factor = 0.0f;
|
||||
memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
|
||||
// TODO: n_dims <= ne0
|
||||
if (op->src[0]->ne[0] != op->op_params[1]) {
|
||||
return false;
|
||||
}
|
||||
// TODO: ext_factor != 0
|
||||
if (ext_factor != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
@@ -2424,9 +2418,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!ggml_is_contiguous(op->src[0])){
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_UPSCALE: {
|
||||
@@ -2496,7 +2487,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
return true;
|
||||
case GGML_OP_SCALE:
|
||||
float bias;
|
||||
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
|
||||
memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
|
||||
return bias == 0.0f; // TODO: support bias != 0.0f
|
||||
case GGML_OP_SOFT_MAX:
|
||||
// TODO: support attention sinks [TAG_ATTN_SINKS]
|
||||
@@ -2505,6 +2496,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
}
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:{
|
||||
#ifdef ASCEND_310P
|
||||
// FA not support on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// derived from [ggml-cuda.cu]
|
||||
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
|
||||
return false;
|
||||
@@ -2523,15 +2518,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
// different head sizes of K and V are not supported yet
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek MLA
|
||||
if (op->src[0]->ne[0] % 16 != 0) {
|
||||
// TODO: padding to support
|
||||
return false;
|
||||
}
|
||||
float logitSoftcap = 0.0f;
|
||||
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
|
||||
memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float));
|
||||
if(logitSoftcap != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -435,7 +435,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
)
|
||||
if (GGML_RVV)
|
||||
if (GGML_XTHEADVECTOR)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d)
|
||||
elseif (GGML_RV_ZFH)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
|
||||
else()
|
||||
@@ -497,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.11.0")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.13.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
@@ -555,6 +555,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
|
||||
@@ -576,7 +577,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
|
||||
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
endif()
|
||||
|
||||
|
||||
@@ -150,8 +150,6 @@
|
||||
#elif defined(__s390x__)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
|
||||
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
|
||||
@@ -23,6 +23,27 @@
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
|
||||
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
|
||||
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
|
||||
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
|
||||
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
|
||||
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
|
||||
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
|
||||
#define B8(c,s ) B7(c,s, c), B7(c,s, s)
|
||||
|
||||
// precomputed tables for expanding 8bits to 8 bytes:
|
||||
static const __attribute__((aligned(16))) uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b ) << 4
|
||||
static const __attribute__((aligned(16))) uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
|
||||
|
||||
// permute mask for byteswapping
|
||||
static const uint8x16_t v_kperm = (const uint8x16_t){
|
||||
7, 6, 5, 4, 3, 2, 1, 0,
|
||||
15, 14, 13, 12, 11, 10, 9, 8
|
||||
};
|
||||
#endif
|
||||
|
||||
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK8_0 == 32);
|
||||
assert(k % QK8_0 == 0);
|
||||
@@ -241,6 +262,301 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(qk == QK5_0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
int ib = 0;
|
||||
float sumf = 0.0f;
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
float32x4_t v_sum0 = vec_splats(0.0f);
|
||||
float32x4_t v_sum1 = vec_splats(0.0f);
|
||||
|
||||
uint32_t qh0, qh1;
|
||||
uint64_t tmp0[4], tmp1[4];
|
||||
|
||||
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_q5_0 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
memcpy(&qh1, x1->qh, sizeof(qh1));
|
||||
|
||||
tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
|
||||
tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
|
||||
tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
|
||||
tmp0[3] = table_b2b_1[(qh0 >> 24) ];
|
||||
|
||||
tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
|
||||
tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
|
||||
tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
|
||||
tmp1[3] = table_b2b_1[(qh1 >> 24) ];
|
||||
|
||||
int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
|
||||
int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
|
||||
int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
|
||||
int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
|
||||
v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
|
||||
v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
|
||||
v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
|
||||
|
||||
const uint8x16_t v_x0 = vec_xl(0, (const uint8_t *)x0->qs);
|
||||
const uint8x16_t v_x1 = vec_xl(0, (const uint8_t *)x1->qs);
|
||||
|
||||
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
||||
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
||||
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
||||
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
||||
|
||||
const int8x16_t v_x0lf = vec_sub(v_x0l, v_qh0l);
|
||||
const int8x16_t v_x0hf = vec_sub(v_x0h, v_qh0h);
|
||||
const int8x16_t v_x1lf = vec_sub(v_x1l, v_qh1l);
|
||||
const int8x16_t v_x1hf = vec_sub(v_x1h, v_qh1h);
|
||||
|
||||
const int8x16_t v_y0l = vec_xl(0, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_y0h = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_y1l = vec_xl(0, (const int8_t *)y1->qs);
|
||||
const int8x16_t v_y1h = vec_xl(QK8_0/2, (const int8_t *)y1->qs);
|
||||
|
||||
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
|
||||
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
|
||||
|
||||
const float32x4_t v_xy0f = vec_float(v_xy0);
|
||||
const float32x4_t v_xy1f = vec_float(v_xy1);
|
||||
|
||||
const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
|
||||
|
||||
v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
|
||||
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
|
||||
}
|
||||
|
||||
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x0->qh, sizeof(qh));
|
||||
|
||||
uint64_t tmp[4];
|
||||
tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
|
||||
tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
|
||||
tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
|
||||
tmp[3] = table_b2b_1[(qh >> 24) ];
|
||||
|
||||
int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
|
||||
int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
|
||||
v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, (const uint8_t *)x0->qs);
|
||||
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
||||
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
||||
|
||||
const int8x16_t v_xlf = vec_sub(v_xl, v_qhl);
|
||||
const int8x16_t v_xhf = vec_sub(v_xh, v_qhh);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0, (const int8_t *)y0->qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_0/2, (const int8_t *)y0->qs);
|
||||
|
||||
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
|
||||
const float32x4_t v_xyf = vec_float(v_xy);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
|
||||
|
||||
sumf += vec_hsum(v_acc);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(x);
|
||||
UNUSED(y);
|
||||
UNUSED(ib);
|
||||
UNUSED(sumf);
|
||||
ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(qk == QK5_1);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_1 * GGML_RESTRICT x = vx;
|
||||
const block_q8_1 * GGML_RESTRICT y = vy;
|
||||
|
||||
int ib = 0;
|
||||
float sumf = 0.0f;
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
float32x4_t v_sum0 = vec_splats(0.0f);
|
||||
float32x4_t v_sum1 = vec_splats(0.0f);
|
||||
|
||||
float summs0 = 0.0f;
|
||||
float summs1 = 0.0f;
|
||||
|
||||
uint32_t qh0;
|
||||
uint32_t qh1;
|
||||
|
||||
uint64_t tmp0[4];
|
||||
uint64_t tmp1[4];
|
||||
|
||||
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_q5_1 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
|
||||
summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s);
|
||||
|
||||
memcpy(&qh0, x0->qh, sizeof(qh0));
|
||||
memcpy(&qh1, x1->qh, sizeof(qh1));
|
||||
|
||||
tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
|
||||
tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
|
||||
tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
|
||||
tmp0[3] = table_b2b_0[(qh0 >> 24) ];
|
||||
|
||||
tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
|
||||
tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
|
||||
tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
|
||||
tmp1[3] = table_b2b_0[(qh1 >> 24) ];
|
||||
|
||||
int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0));
|
||||
int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2));
|
||||
int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0));
|
||||
int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm);
|
||||
v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm);
|
||||
v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm);
|
||||
v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm);
|
||||
|
||||
const uint8x16_t v_x0 = vec_xl(0, x0->qs);
|
||||
const uint8x16_t v_x1 = vec_xl(0, x1->qs);
|
||||
|
||||
const int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
||||
const int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
||||
const int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
||||
const int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
||||
|
||||
const int8x16_t v_x0lf = vec_or(v_x0l, v_qh0l);
|
||||
const int8x16_t v_x0hf = vec_or(v_x0h, v_qh0h);
|
||||
const int8x16_t v_x1lf = vec_or(v_x1l, v_qh1l);
|
||||
const int8x16_t v_x1hf = vec_or(v_x1h, v_qh1h);
|
||||
|
||||
const int8x16_t v_y0l = vec_xl(0 , y0->qs);
|
||||
const int8x16_t v_y0h = vec_xl(QK8_1/2, y0->qs);
|
||||
const int8x16_t v_y1l = vec_xl(0 , y1->qs);
|
||||
const int8x16_t v_y1h = vec_xl(QK8_1/2, y1->qs);
|
||||
|
||||
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h);
|
||||
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h);
|
||||
|
||||
const float32x4_t v_xy0f = vec_float(v_xy0);
|
||||
const float32x4_t v_xy1f = vec_float(v_xy1);
|
||||
|
||||
const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d));
|
||||
|
||||
v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0);
|
||||
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
|
||||
}
|
||||
|
||||
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (; ib < nb; ++ib) {
|
||||
const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
|
||||
const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
|
||||
|
||||
float summs = GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, x0->qh, sizeof(qh));
|
||||
|
||||
uint64_t tmp[4];
|
||||
tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
|
||||
tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
|
||||
tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
|
||||
tmp[3] = table_b2b_0[(qh >> 24) ];
|
||||
|
||||
int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0));
|
||||
int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2));
|
||||
|
||||
// required for fixing the byteorder
|
||||
v_qhl = vec_perm(v_qhl, v_qhl, v_kperm);
|
||||
v_qhh = vec_perm(v_qhh, v_qhh, v_kperm);
|
||||
|
||||
const uint8x16_t v_x = vec_xl(0, x0->qs);
|
||||
const int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
||||
const int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
||||
|
||||
const int8x16_t v_xlf = vec_or(v_xl, v_qhl);
|
||||
const int8x16_t v_xhf = vec_or(v_xh, v_qhh);
|
||||
|
||||
const int8x16_t v_yl = vec_xl(0 , y0->qs);
|
||||
const int8x16_t v_yh = vec_xl(QK8_1/2, y0->qs);
|
||||
|
||||
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh);
|
||||
const float32x4_t v_xyf = vec_float(v_xy);
|
||||
|
||||
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
|
||||
const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc);
|
||||
|
||||
sumf += vec_hsum(v_acc) + summs;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(x);
|
||||
UNUSED(y);
|
||||
UNUSED(ib);
|
||||
UNUSED(sumf);
|
||||
ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -486,6 +486,14 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
|
||||
return v_abo + v_abe;
|
||||
}
|
||||
|
||||
/**
|
||||
* @see https://github.com/ggml-org/llama.cpp/pull/14037
|
||||
*/
|
||||
inline static float vec_hsum(float32x4_t v) {
|
||||
float32x4_t v_temp = v + vec_reve(v);
|
||||
return v_temp[0] + v_temp[1];
|
||||
}
|
||||
|
||||
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
||||
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
|
||||
return acc + (vec_unpackh(p) + vec_unpackl(p));
|
||||
|
||||
@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_conv_2d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
ggml_compute_forward_conv_3d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
{
|
||||
ggml_compute_forward_conv_2d_dw(params, tensor);
|
||||
@@ -2252,6 +2256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -2773,6 +2778,7 @@ struct ggml_cplan ggml_graph_plan(
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
{
|
||||
cur = GGML_IM2COL_WORK_SIZE;
|
||||
} break;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
||||
|
||||
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
||||
@@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
@@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
@@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
@@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
@@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
@@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
@@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
@@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
@@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
@@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
@@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
@@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
},
|
||||
/* .lhs_info = */ {
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
|
||||
@@ -84,8 +84,11 @@ struct rhs_packing_info {
|
||||
|
||||
struct ggml_kleidiai_kernels {
|
||||
kernel_info gemm;
|
||||
lhs_packing_info gemm_lhs_info;
|
||||
|
||||
kernel_info gemv;
|
||||
lhs_packing_info lhs_info;
|
||||
lhs_packing_info gemv_lhs_info;
|
||||
|
||||
rhs_packing_info rhs_info;
|
||||
|
||||
cpu_feature required_cpu;
|
||||
|
||||
@@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
}
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
||||
GGML_ASSERT(kernels);
|
||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
bool is_gemv = op->src[1]->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
|
||||
size_t k = op->src[0]->ne[0];
|
||||
size_t n = op->src[0]->ne[1];
|
||||
@@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
|
||||
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
||||
k * n * sizeof(float) + n * sizeof(float);
|
||||
} else {
|
||||
@@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
GGML_ASSERT(kernel);
|
||||
|
||||
const int nth = params->nth;
|
||||
@@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
|
||||
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
|
||||
|
||||
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
|
||||
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
|
||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
|
||||
const size_t kxn_size = k * n * sizeof(float);
|
||||
const size_t bias_size = n * sizeof(float);
|
||||
@@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||
|
||||
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
|
||||
|
||||
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
|
||||
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
|
||||
|
||||
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
||||
variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
|
||||
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = &kernels->lhs_info;
|
||||
bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
|
||||
|
||||
@@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC {
|
||||
class tinyBLAS_PPC {
|
||||
public:
|
||||
tinyBLAS_PPC(int64_t k,
|
||||
const float *A, int64_t lda,
|
||||
const float *B, int64_t ldb,
|
||||
float *C, int64_t ldc,
|
||||
const float * A, int64_t lda,
|
||||
const float * B, int64_t ldb,
|
||||
float * C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
mnpack(0, m, 0, n);
|
||||
int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
|
||||
if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
|
||||
matmul_tiled(m, n, mc, nc, kc);
|
||||
} else {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
|
||||
|
||||
inline void vector_permute_store_4(vector float *src, float *vecOffset) {
|
||||
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
t1 = vec_mergeh(src[0], src[1]);
|
||||
t2 = vec_mergeh(src[2], src[3]);
|
||||
t3 = vec_mergel(src[0], src[1]);
|
||||
t4 = vec_mergel(src[2], src[3]);
|
||||
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t1, t2, 3);
|
||||
t7 = vec_xxpermdi(t3, t4, 0);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset + 4);
|
||||
vec_xst(t7, 0, vecOffset + 8);
|
||||
vec_xst(t8, 0, vecOffset + 12);
|
||||
}
|
||||
|
||||
inline void vector_permute_store_8(vector float *src, float *vecOffset) {
|
||||
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
t1 = vec_mergeh(src[0], src[1]);
|
||||
t2 = vec_mergeh(src[2], src[3]);
|
||||
t3 = vec_mergeh(src[4], src[5]);
|
||||
t4 = vec_mergeh(src[6], src[7]);
|
||||
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset + 4);
|
||||
vec_xst(t7, 0, vecOffset + 8);
|
||||
vec_xst(t8, 0, vecOffset + 12);
|
||||
|
||||
t1 = vec_mergel(src[0], src[1]);
|
||||
t2 = vec_mergel(src[2], src[3]);
|
||||
t3 = vec_mergel(src[4], src[5]);
|
||||
t4 = vec_mergel(src[6], src[7]);
|
||||
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
|
||||
vec_xst(t5, 0, vecOffset + 16);
|
||||
vec_xst(t6, 0, vecOffset + 20);
|
||||
vec_xst(t7, 0, vecOffset + 24);
|
||||
vec_xst(t8, 0, vecOffset + 28);
|
||||
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
|
||||
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
|
||||
*c_ptr += *((float *)&vec_C[I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void vector_permute_store_4(vector float * src, float * vecOffset) {
|
||||
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
t1 = vec_mergeh(src[0], src[1]);
|
||||
t2 = vec_mergeh(src[2], src[3]);
|
||||
t3 = vec_mergel(src[0], src[1]);
|
||||
t4 = vec_mergel(src[2], src[3]);
|
||||
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t1, t2, 3);
|
||||
t7 = vec_xxpermdi(t3, t4, 0);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset + 4);
|
||||
vec_xst(t7, 0, vecOffset + 8);
|
||||
vec_xst(t8, 0, vecOffset + 12);
|
||||
}
|
||||
|
||||
inline void vector_permute_store_8(vector float * src, float * vecOffset) {
|
||||
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
t1 = vec_mergeh(src[0], src[1]);
|
||||
t2 = vec_mergeh(src[2], src[3]);
|
||||
t3 = vec_mergeh(src[4], src[5]);
|
||||
t4 = vec_mergeh(src[6], src[7]);
|
||||
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset + 4);
|
||||
vec_xst(t7, 0, vecOffset + 8);
|
||||
vec_xst(t8, 0, vecOffset + 12);
|
||||
|
||||
t1 = vec_mergel(src[0], src[1]);
|
||||
t2 = vec_mergel(src[2], src[3]);
|
||||
t3 = vec_mergel(src[4], src[5]);
|
||||
t4 = vec_mergel(src[6], src[7]);
|
||||
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
|
||||
vec_xst(t5, 0, vecOffset + 16);
|
||||
vec_xst(t6, 0, vecOffset + 20);
|
||||
vec_xst(t7, 0, vecOffset + 24);
|
||||
vec_xst(t8, 0, vecOffset + 28);
|
||||
}
|
||||
|
||||
void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
|
||||
int64_t i, j;
|
||||
float * aoffsets[8];
|
||||
float *aoffset = NULL, *boffset = NULL;
|
||||
float * aoffset = NULL, * boffset = NULL;
|
||||
__vector_pair arr[8];
|
||||
vector float c[8][2] = {0};
|
||||
vector float c1[8] = {0};
|
||||
vector float c2[8] = {0};
|
||||
aoffset = const_cast<float*>(a);
|
||||
aoffset = const_cast<float *>(a);
|
||||
boffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
|
||||
do {
|
||||
aoffsets[0] = aoffset;
|
||||
for (int it = 1; it< 8; it++)
|
||||
for (int it = 1; it < 8; it++)
|
||||
aoffsets[it] = aoffsets[it-1] + lda;
|
||||
aoffset += 8 * lda;
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
for (int it = 0; it< 8; it++) {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
@@ -2264,11 +2287,14 @@ class tinyBLAS_PPC {
|
||||
}
|
||||
|
||||
vector_permute_store_8(c1, boffset);
|
||||
vector_permute_store_8(c2, boffset+32);
|
||||
for (int it = 0; it < 4; it++)
|
||||
aoffsets[it] = aoffsets[it] + 8*lda;
|
||||
vector_permute_store_8(c2, boffset + 32);
|
||||
boffset += 64;
|
||||
i--;
|
||||
if (i > 0) {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
aoffsets[it] = aoffsets[it] + 8;
|
||||
}
|
||||
}
|
||||
} while(i > 0);
|
||||
}
|
||||
if (cols & 4) {
|
||||
@@ -2295,9 +2321,9 @@ class tinyBLAS_PPC {
|
||||
c2[it] = c[it][1];
|
||||
}
|
||||
vector_permute_store_4(c1, boffset);
|
||||
vector_permute_store_4(c2, boffset+16);
|
||||
vector_permute_store_4(c2, boffset + 16);
|
||||
for (int it = 0; it < 4; it++)
|
||||
aoffsets[it] += 8*lda;
|
||||
aoffsets[it] += 8 * lda;
|
||||
boffset += 32;
|
||||
i--;
|
||||
} while(i > 0);
|
||||
@@ -2325,15 +2351,15 @@ class tinyBLAS_PPC {
|
||||
vec_t vec_A[4], vec_B[4], vec_C[4];
|
||||
acc_t acc_0;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
for (int l = 0; l < k; l+=4) {
|
||||
packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
||||
packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||
for (int l = 0; l < k; l += 4) {
|
||||
packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
|
||||
packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
save_acc(&acc_0, ii, jj);
|
||||
}
|
||||
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
@@ -2341,9 +2367,9 @@ class tinyBLAS_PPC {
|
||||
acc_t acc_0, acc_1;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
for (int64_t l = 0; l < k; l+=4) {
|
||||
packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
||||
packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
|
||||
for (int64_t l = 0; l < k; l += 4) {
|
||||
packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
|
||||
packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
|
||||
@@ -2353,8 +2379,8 @@ class tinyBLAS_PPC {
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
SAVE_ACC(&acc_1, ii, jj+4);
|
||||
save_acc(&acc_0, ii, jj);
|
||||
save_acc(&acc_1, ii, jj + 4);
|
||||
}
|
||||
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
@@ -2362,9 +2388,9 @@ class tinyBLAS_PPC {
|
||||
acc_t acc_0, acc_1;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
for (int64_t l = 0; l < k; l+=4) {
|
||||
packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
|
||||
packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||
for (int64_t l = 0; l < k; l += 4) {
|
||||
packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
|
||||
packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
|
||||
@@ -2374,8 +2400,8 @@ class tinyBLAS_PPC {
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
SAVE_ACC(&acc_1, ii+4, jj);
|
||||
save_acc(&acc_0, ii, jj);
|
||||
save_acc(&acc_1, ii + 4, jj);
|
||||
}
|
||||
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
@@ -2386,19 +2412,96 @@ class tinyBLAS_PPC {
|
||||
__builtin_mma_xxsetaccz(&acc_2);
|
||||
__builtin_mma_xxsetaccz(&acc_3);
|
||||
for (int l = 0; l < k; l+=8) {
|
||||
packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
|
||||
packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
|
||||
packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
|
||||
packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
|
||||
for(int x = 0; x < 16; x+=2) {
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
|
||||
}
|
||||
}
|
||||
save_acc(&acc_0, ii, jj);
|
||||
save_acc(&acc_1, ii, jj + 4);
|
||||
save_acc(&acc_2, ii + 4, jj);
|
||||
save_acc(&acc_3, ii + 4, jj + 4);
|
||||
}
|
||||
|
||||
inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
|
||||
for (int x = 0; x < 16; x += 2) {
|
||||
__builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
|
||||
__builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
|
||||
__builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
|
||||
__builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
|
||||
for (int64_t i = 0; i < mc; i += 16) {
|
||||
int A_base_addr = (mc / 8) * (i / 8) * 16;
|
||||
for (int64_t j = 0; j < nc; j += 8) {
|
||||
int B_base_addr = (nc / 8) * (j / 8) * 16;
|
||||
acc_t acc[8];
|
||||
vec_t A0_block[16]; vec_t A1_block[16];
|
||||
for (int x = 0; x < 8; x++)
|
||||
__builtin_mma_xxsetaccz(&acc[x]);
|
||||
for (int64_t l = 0; l < kc; l += 8) {
|
||||
int A0_block_idx = A_base_addr + (l / 8) * 16;
|
||||
int A1_block_idx = A0_block_idx + (mc / 8) * 16;
|
||||
int B_block_idx = B_base_addr + (l / 8) * 16;
|
||||
vec_t* A0_block = &vec_A[A0_block_idx];
|
||||
vec_t* A1_block = &vec_A[A1_block_idx];
|
||||
vec_t* B_block = &vec_B[B_block_idx];
|
||||
MMA_16x8(A0_block, A1_block, B_block, acc);
|
||||
}
|
||||
if (kk == 0) {
|
||||
save_acc(&acc[0], ii + i, jj + j);
|
||||
save_acc(&acc[1], ii + i, jj + j + 4);
|
||||
save_acc(&acc[2], ii + i + 4, jj + j);
|
||||
save_acc(&acc[3], ii + i + 4, jj + j + 4);
|
||||
save_acc(&acc[4], ii + i + 8, jj + j);
|
||||
save_acc(&acc[5], ii + i + 8, jj + j + 4);
|
||||
save_acc(&acc[6], ii + i + 12, jj + j);
|
||||
save_acc(&acc[7], ii + i + 12, jj + j + 4);
|
||||
} else {
|
||||
add_save_acc(&acc[0], ii + i, jj + j);
|
||||
add_save_acc(&acc[1], ii + i, jj + j + 4);
|
||||
add_save_acc(&acc[2], ii + i + 4, jj + j);
|
||||
add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
|
||||
add_save_acc(&acc[4], ii + i + 8, jj + j);
|
||||
add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
|
||||
add_save_acc(&acc[6], ii + i + 12, jj + j);
|
||||
add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
||||
int64_t ytiles = m / mc;
|
||||
int64_t xtiles = n / nc;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles) {
|
||||
end = tiles;
|
||||
}
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
vec_t A_pack[kc * mc / 4];
|
||||
vec_t B_pack[kc * nc / 4];
|
||||
packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
|
||||
packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
|
||||
KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
SAVE_ACC(&acc_1, ii, jj+4);
|
||||
SAVE_ACC(&acc_2, ii+4, jj);
|
||||
SAVE_ACC(&acc_3, ii+4, jj+4);
|
||||
}
|
||||
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
@@ -2406,35 +2509,35 @@ class tinyBLAS_PPC {
|
||||
int n_rem = MIN(n - n0, 8);
|
||||
int mc = 0, nc = 0;
|
||||
if (m_rem >= 8 && n_rem >= 8) {
|
||||
mc = 8;
|
||||
nc = 8;
|
||||
gemm<8, 8>(m0, m, n0, n);
|
||||
mc = 8;
|
||||
nc = 8;
|
||||
gemm<8, 8>(m0, m, n0, n);
|
||||
} else if (m_rem >= 4 && n_rem >= 8) {
|
||||
mc = 4;
|
||||
nc = 8;
|
||||
gemm<4, 8>(m0, m, n0, n);
|
||||
mc = 4;
|
||||
nc = 8;
|
||||
gemm<4, 8>(m0, m, n0, n);
|
||||
} else if (m_rem >= 8 && n_rem >= 4) {
|
||||
mc = 8;
|
||||
nc = 4;
|
||||
gemm<8, 4>(m0, m, n0, n);
|
||||
mc = 8;
|
||||
nc = 4;
|
||||
gemm<8, 4>(m0, m, n0, n);
|
||||
} else if (m_rem >= 4 && n_rem >= 4) {
|
||||
mc = 4;
|
||||
nc = 4;
|
||||
gemm<4, 4>(m0, m, n0, n);
|
||||
mc = 4;
|
||||
nc = 4;
|
||||
gemm<4, 4>(m0, m, n0, n);
|
||||
} else {
|
||||
mc = (m_rem >= 4) ? 4 : m_rem;
|
||||
nc = (n_rem >= 4) ? 4 : n_rem;
|
||||
if (mc == 0 || nc == 0)
|
||||
return;
|
||||
return;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
}
|
||||
int64_t mp = m0 + ((m - m0) / mc) * mc;
|
||||
int64_t np = n0 + ((n - n0) / nc) * nc;
|
||||
mnpack(mp, m, n0, np);
|
||||
mnpack(m0, m, np, n);
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
@@ -2449,30 +2552,30 @@ class tinyBLAS_PPC {
|
||||
vec_t vec_C[4];
|
||||
acc_t acc_0;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
vec_t vec_A[4] {0}, vec_B[4] = {0};
|
||||
for (int l=0; l<k; l+=4) {
|
||||
vec_t vec_A[4] = {0}, vec_B[4] = {0};
|
||||
for (int l = 0; l < k; l += 4) {
|
||||
/* 'GEMV Forwarding' concept is used in first two conditional loops.
|
||||
* when one of the matrix has a single row/column, the elements are
|
||||
* broadcasted, instead of using packing routine to prepack the
|
||||
* matrix elements.
|
||||
*/
|
||||
if (RM == 1) {
|
||||
float* a = const_cast<float*>(A+(ii)*lda+l);
|
||||
packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
|
||||
float * a = const_cast<float *>(A + (ii) * lda + l);
|
||||
packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
|
||||
vec_A[0] = (vec_t)vec_xl(0,a);
|
||||
vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
|
||||
vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
|
||||
vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
|
||||
vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
|
||||
vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
|
||||
vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
|
||||
} else if (RN == 1) {
|
||||
packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
|
||||
float* b = const_cast<float*>(B+(jj)*ldb+l);
|
||||
packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
|
||||
float * b = const_cast<float *>(B + (jj) * ldb + l);
|
||||
vec_B[0] = (vec_t)vec_xl(0,b);
|
||||
vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
|
||||
vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
|
||||
vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
|
||||
vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
|
||||
vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
|
||||
vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
|
||||
} else {
|
||||
packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
|
||||
packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
|
||||
packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
|
||||
packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
|
||||
}
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
|
||||
@@ -2482,12 +2585,27 @@ class tinyBLAS_PPC {
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
|
||||
*((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int RM, int RN>
|
||||
inline void kernel(int64_t ii, int64_t jj) {
|
||||
if constexpr(RM == 4 && RN == 4) {
|
||||
KERNEL_4x4(ii, jj);
|
||||
} else if constexpr(RM == 4 && RN == 8) {
|
||||
KERNEL_4x8(ii, jj);
|
||||
} else if constexpr(RM == 8 && RN == 4) {
|
||||
KERNEL_8x4(ii, jj);
|
||||
} else if constexpr(RM == 8 && RN == 8) {
|
||||
KERNEL_8x8(ii, jj);
|
||||
} else {
|
||||
static_assert(false, "RN/RM values not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <int RM, int RN>
|
||||
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
@@ -2496,27 +2614,18 @@ class tinyBLAS_PPC {
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (RM == 4 && RN == 4) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_4x4;
|
||||
} else if (RM == 4 && RN == 8) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_4x8;
|
||||
} else if (RM == 8 && RN == 4) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_8x4;
|
||||
} else if (RM == 8 && RN == 8) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_8x8;
|
||||
}
|
||||
if (end > tiles)
|
||||
end = tiles;
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = m0 + job / xtiles * RM;
|
||||
int64_t jj = n0 + job % xtiles * RN;
|
||||
(this->*kernel)(ii, jj);
|
||||
kernel<RM, RN>(ii, jj);
|
||||
}
|
||||
}
|
||||
|
||||
const float *const A;
|
||||
const float *const B;
|
||||
float *C;
|
||||
const float * const A;
|
||||
const float * const B;
|
||||
float * C;
|
||||
const int64_t k;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
|
||||
+158
-12
@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
|
||||
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_3d
|
||||
|
||||
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
|
||||
const ggml_tensor * kernel,
|
||||
const ggml_tensor * src,
|
||||
ggml_tensor * dst,
|
||||
ggml_type kernel_type) {
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(kernel));
|
||||
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(kernel->type == kernel_type);
|
||||
|
||||
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
|
||||
|
||||
const int32_t s0 = dst->op_params[0];
|
||||
const int32_t s1 = dst->op_params[1];
|
||||
const int32_t s2 = dst->op_params[2];
|
||||
const int32_t p0 = dst->op_params[3];
|
||||
const int32_t p1 = dst->op_params[4];
|
||||
const int32_t p2 = dst->op_params[5];
|
||||
const int32_t d0 = dst->op_params[6];
|
||||
const int32_t d1 = dst->op_params[7];
|
||||
const int32_t d2 = dst->op_params[8];
|
||||
const int32_t c = dst->op_params[9];
|
||||
const int32_t n = dst->op_params[10];
|
||||
const int32_t oc = dst->op_params[11];
|
||||
|
||||
const int64_t src_w = src->ne[0];
|
||||
const int64_t src_h = src->ne[1];
|
||||
const int64_t src_d = src->ne[2];
|
||||
const int64_t knl_w = kernel->ne[0];
|
||||
const int64_t knl_h = kernel->ne[1];
|
||||
const int64_t knl_d = kernel->ne[2];
|
||||
const int64_t dst_w = dst->ne[0];
|
||||
const int64_t dst_h = dst->ne[1];
|
||||
const int64_t dst_d = dst->ne[2];
|
||||
|
||||
const float * src_data = (float *) src->data;
|
||||
void * knl_data = kernel->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
|
||||
const int64_t knl_n_total = knl_n_per_channel * c;
|
||||
const int64_t patch_total = n * dst_w * dst_h * dst_d;
|
||||
|
||||
const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
|
||||
const int64_t batch_size = params->wsize / space_per_patch;
|
||||
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
|
||||
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
|
||||
|
||||
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
|
||||
|
||||
void * tmp = params->wdata;
|
||||
|
||||
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
|
||||
const int64_t patch_start_batch = batch_i * patches_per_batch;
|
||||
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
|
||||
const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
|
||||
|
||||
const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
|
||||
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
|
||||
|
||||
for (int64_t p = patch_start; p < patch_end; ++p) {
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
|
||||
|
||||
for (int64_t ic = 0; ic < c; ++ic) {
|
||||
for (int64_t kz = 0; kz < knl_d; ++kz) {
|
||||
for (int64_t ky = 0; ky < knl_h; ++ky) {
|
||||
for (int64_t kx = 0; kx < knl_w; ++kx) {
|
||||
const int64_t sz = dst_z * s2 + kz * d2 - p2;
|
||||
const int64_t sy = dst_y * s1 + ky * d1 - p1;
|
||||
const int64_t sx = dst_x * s0 + kx * d0 - p0;
|
||||
|
||||
int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
|
||||
|
||||
float src_val;
|
||||
if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
|
||||
src_val = 0.0f;
|
||||
} else {
|
||||
const int64_t cn_idx = batch_idx * c + ic;
|
||||
const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
|
||||
src_val = *src_ptr;
|
||||
}
|
||||
|
||||
char * element_ptr = dst_row + dst_idx * traits->type_size;
|
||||
if (kernel_type == GGML_TYPE_F32) {
|
||||
*(float *)element_ptr = src_val;
|
||||
} else if (kernel_type == GGML_TYPE_F16) {
|
||||
*(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
|
||||
ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
|
||||
const int64_t permute_start = params->ith * permute_per_thread;
|
||||
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
|
||||
|
||||
for (int64_t i = permute_start; i < permute_end; ++i) {
|
||||
const int64_t p = patch_start_batch + i;
|
||||
const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
|
||||
const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
|
||||
const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
|
||||
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
|
||||
const int64_t dst_y = p_in_depth / dst_w;
|
||||
const int64_t dst_x = p_in_depth % dst_w;
|
||||
|
||||
for (int64_t ioc = 0; ioc < oc; ++ioc) {
|
||||
const float value = gemm_output[i * oc + ioc];
|
||||
const int64_t ocn_idx = batch_idx * oc + ioc;
|
||||
float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
|
||||
*dst_ptr = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_conv_3d(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_transpose_2d
|
||||
|
||||
void ggml_compute_forward_conv_transpose_2d(
|
||||
@@ -8861,8 +9003,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
|
||||
// allows optimizing the modulo since n_group should be a power of 2
|
||||
GGML_ASSERT((ng & -ng) == ng);
|
||||
GGML_ASSERT(nh % ng == 0);
|
||||
|
||||
// heads per thread
|
||||
const int dh = (nh + nth - 1)/nth;
|
||||
@@ -8893,6 +9034,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
||||
const float dA = expf(dt_soft_plus * A[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
// dim
|
||||
for (int i1 = 0; i1 < nr; ++i1) {
|
||||
@@ -8915,8 +9057,8 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// TODO: maybe unroll more?
|
||||
for (int j = 0; j < 1; j++) {
|
||||
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
|
||||
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
|
||||
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
|
||||
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
|
||||
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
|
||||
|
||||
t0 = GGML_F32_VEC_MUL(t0, adA);
|
||||
t1 = GGML_F32_VEC_MUL(t1, axdt);
|
||||
@@ -8930,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
}
|
||||
|
||||
sumf = GGML_F32xt_REDUCE_ONE(sum);
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV implementation
|
||||
const int np = 0;
|
||||
#else
|
||||
const int np = (nc & ~(GGML_F32_STEP - 1));
|
||||
|
||||
@@ -8945,8 +9090,8 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
||||
for (int j = 0; j < GGML_F32_ARR; j++) {
|
||||
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
|
||||
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
|
||||
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
|
||||
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
|
||||
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
|
||||
|
||||
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
|
||||
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
|
||||
@@ -8968,7 +9113,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// d_state
|
||||
for (int i0 = np; i0 < nc; ++i0) {
|
||||
const int i = i0 + ii*nc;
|
||||
const int ig = i0 + (h & (ng - 1))*nc;
|
||||
const int ig = i0 + g*nc;
|
||||
// state = prev_state * dA + dB * x
|
||||
const float state = (s0[i] * dA) + (B[ig] * x_dt);
|
||||
// y = rowwise_dotprod(state, C)
|
||||
@@ -8985,6 +9130,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
for (int h = ih0; h < ih1; ++h) {
|
||||
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
|
||||
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
// dim
|
||||
for (int i1 = 0; i1 < nr; ++i1) {
|
||||
@@ -8999,8 +9145,8 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// TODO: what happens when (d_state % svcntw()) != 0?
|
||||
for (int64_t k = 0; k < nc; k += svcntw()) {
|
||||
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
|
||||
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
|
||||
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
|
||||
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
|
||||
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
|
||||
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
|
||||
|
||||
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
|
||||
@@ -9020,7 +9166,7 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||
// d_state
|
||||
for (int i0 = 0; i0 < nc; ++i0) {
|
||||
const int i = i0 + ii*nc;
|
||||
const int ig = i0 + (h & (ng - 1))*nc;
|
||||
const int ig = i0 + g*nc;
|
||||
// state = prev_state * dA + dB * x
|
||||
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
|
||||
// y = rowwise_dotprod(state, C)
|
||||
@@ -9881,8 +10027,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
|
||||
int64_t h_stride_2d = head_size * head_size;
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
// scalar Route to scalar implementation //TODO: Write SVE code
|
||||
#if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
|
||||
// scalar Route to scalar implementation //TODO: Write SVE code and RVV code
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
int64_t t_offset = t * t_stride;
|
||||
int64_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
|
||||
@@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
|
||||
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
#include <riscv_vector.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -94,24 +98,15 @@ extern "C" {
|
||||
}
|
||||
#elif defined(__riscv) && defined(__riscv_zfhmin)
|
||||
static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {
|
||||
float f;
|
||||
__asm__(
|
||||
"fmv.h.x %[f], %[h]\n\t"
|
||||
"fcvt.s.h %[f], %[f]"
|
||||
: [f] "=&f" (f)
|
||||
: [h] "r" (h)
|
||||
);
|
||||
return f;
|
||||
_Float16 hf;
|
||||
memcpy(&hf, &h, sizeof(ggml_fp16_t));
|
||||
return hf;
|
||||
}
|
||||
|
||||
static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) {
|
||||
ggml_fp16_t res;
|
||||
__asm__(
|
||||
"fcvt.h.s %[f], %[f]\n\t"
|
||||
"fmv.x.h %[h], %[f]"
|
||||
: [h] "=&r" (res)
|
||||
: [f] "f" (f)
|
||||
);
|
||||
_Float16 hf = (_Float16)f;
|
||||
memcpy(&res, &hf, sizeof(ggml_fp16_t));
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -220,6 +215,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
||||
#define GGML_F32_VEC_MUL GGML_F32xt_MUL
|
||||
#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
|
||||
|
||||
// F16 SVE
|
||||
#define DEFAULT_PG32 svptrue_b32()
|
||||
#define DEFAULT_PG16 svptrue_b16()
|
||||
|
||||
#define GGML_F32Cxt svfloat16_t
|
||||
#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f)
|
||||
#define GGML_F32Cxt_SET1(x) svdup_n_f16(x)
|
||||
#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p))
|
||||
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
|
||||
|
||||
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
|
||||
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
|
||||
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
|
||||
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
|
||||
|
||||
#define GGML_F16x_VEC GGML_F32Cxt
|
||||
#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO
|
||||
#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1
|
||||
#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p)
|
||||
#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r)
|
||||
#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA
|
||||
#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD
|
||||
#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL
|
||||
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
|
||||
|
||||
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
|
||||
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
|
||||
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
|
||||
{ \
|
||||
sum1 = svadd_f16_x(pg16, sum1, sum2); \
|
||||
sum3 = svadd_f16_x(pg16, sum3, sum4); \
|
||||
sum1 = svadd_f16_x(pg16, sum1, sum3); \
|
||||
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
|
||||
(res) = (ggml_float) sum_f16; \
|
||||
}
|
||||
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
|
||||
|
||||
// F16 NEON
|
||||
|
||||
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
|
||||
@@ -1170,6 +1206,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
|
||||
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
|
||||
// compatible with vlen >= 128
|
||||
|
||||
#define GGML_SIMD
|
||||
|
||||
// F32
|
||||
|
||||
#define GGML_F32_STEP 16
|
||||
#define GGML_F32_EPR 4
|
||||
|
||||
#define GGML_F32x4 vfloat32m1_t
|
||||
#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR)
|
||||
#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR)
|
||||
#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR)
|
||||
#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR)
|
||||
#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR)
|
||||
#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR)
|
||||
#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR)
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
|
||||
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
|
||||
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
|
||||
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
|
||||
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
|
||||
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
|
||||
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
#endif
|
||||
|
||||
// GGML_F32_ARR / GGML_F16_ARR
|
||||
|
||||
+123
-19
@@ -84,6 +84,16 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
|
||||
}
|
||||
// reduce sum1,sum2 to sum1
|
||||
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
||||
for (int i = 0, avl; i < n; i += avl) {
|
||||
avl = __riscv_vsetvl_e32m8(n - i);
|
||||
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
|
||||
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
|
||||
vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
|
||||
vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);
|
||||
}
|
||||
sumf += __riscv_vfmv_f_s_f32m1_f32(vsum);
|
||||
#else
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
||||
@@ -197,33 +207,97 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
|
||||
|
||||
ggml_float sumf = 0.0;
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
|
||||
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8; //get vector length
|
||||
const int ggml_f16_epr = sve_register_length / 16; // running when 16
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
const int np= (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t sum1 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum2 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum3 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum4 = svdup_n_f16(0.0f);
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
|
||||
sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);
|
||||
|
||||
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
|
||||
sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||
sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);
|
||||
|
||||
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);
|
||||
|
||||
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||
sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
|
||||
|
||||
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||
sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);
|
||||
|
||||
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||
sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);
|
||||
|
||||
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||
sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
|
||||
}
|
||||
}
|
||||
|
||||
// reduce sum0..sum3 to sum0
|
||||
GGML_F16_VEC_REDUCE(sumf, sum);
|
||||
const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8
|
||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
|
||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||
sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b16(np2, n);
|
||||
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
|
||||
// if you hit this, you are likely running outside the FP range
|
||||
assert(!isnan(sumf) && !isinf(sumf));
|
||||
sum1 = svmad_f16_x(pg, hx, hy, sum1);
|
||||
}
|
||||
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
|
||||
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// reduce sum0..sum3 to sum0
|
||||
GGML_F16_VEC_REDUCE(sumf, sum);
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
// if you hit this, you are likely running outside the FP range
|
||||
assert(!isnan(sumf) && !isinf(sumf));
|
||||
#endif
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
@@ -247,6 +321,12 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
||||
for (; i + 3 < n; i += 4) {
|
||||
_mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
|
||||
}
|
||||
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
|
||||
const int vlen = svcntw();
|
||||
for (; i < n; i += vlen) {
|
||||
const svbool_t pg = svwhilelt_b32_s32(i, n);
|
||||
svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
|
||||
@@ -271,6 +351,12 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
|
||||
for (; i + 3 < n; i += 4) {
|
||||
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
|
||||
}
|
||||
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
|
||||
const int vlen = svcntw();
|
||||
for (; i < n; i += vlen) {
|
||||
const svbool_t pg = svwhilelt_b32_s32(i, n);
|
||||
svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
|
||||
@@ -318,6 +404,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
|
||||
#endif
|
||||
sum += (ggml_float)_mm_cvtss_f32(val);
|
||||
}
|
||||
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
|
||||
const int vlen = svcntw();
|
||||
for (; i < n; i += vlen) {
|
||||
const svbool_t pg = svwhilelt_b32_s32(i, n);
|
||||
svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
|
||||
svdup_n_f32_x(pg, max)));
|
||||
svst1_f32(pg, y + i, val);
|
||||
sum += (ggml_float)svaddv_f32(pg, val);
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
|
||||
@@ -325,6 +420,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
|
||||
vst1q_f32(y + i, val);
|
||||
sum += (ggml_float)vaddvq_f32(val);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
|
||||
for (int avl; i < n; i += avl) {
|
||||
avl = __riscv_vsetvl_e32m2(n - i);
|
||||
vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
|
||||
__riscv_vse32_v_f32m2(&y[i], val, avl);
|
||||
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
|
||||
}
|
||||
return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = expf(x[i] - max);
|
||||
|
||||
+393
-52
@@ -119,36 +119,149 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
}
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
|
||||
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16; // running when 16
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
svfloat16_t sum_00 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_01 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_02 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_03 = svdup_n_f16(0.0f);
|
||||
|
||||
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
|
||||
svfloat16_t sum_10 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_11 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_12 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_13 = svdup_n_f16(0.0f);
|
||||
|
||||
sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
|
||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
|
||||
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
|
||||
|
||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
|
||||
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
|
||||
ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
|
||||
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
|
||||
ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
|
||||
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);
|
||||
|
||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||
|
||||
ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);
|
||||
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
|
||||
ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);
|
||||
|
||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||
|
||||
ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);
|
||||
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
|
||||
ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);
|
||||
|
||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||
|
||||
ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);
|
||||
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
|
||||
ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);
|
||||
|
||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||
|
||||
ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);
|
||||
|
||||
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
|
||||
ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
|
||||
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
|
||||
}
|
||||
|
||||
const int np2 = (n & ~(ggml_f16_epr - 1));
|
||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||
|
||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
|
||||
rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
|
||||
}
|
||||
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b16(np2, n);
|
||||
svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
|
||||
svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
|
||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
|
||||
sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
|
||||
sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
|
||||
}
|
||||
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
|
||||
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV impl
|
||||
for (int i = 0; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
|
||||
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
|
||||
|
||||
sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reduce sum0..sum3 to sum0
|
||||
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
|
||||
GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
// reduce sum0..sum3 to sum0
|
||||
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
|
||||
GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
@@ -243,6 +356,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
||||
|
||||
svst1_f32(pg, y + np2, ay1);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
for (int i = 0, avl; i < n; i += avl) {
|
||||
avl = __riscv_vsetvl_e32m8(n - i);
|
||||
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
|
||||
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
|
||||
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
|
||||
__riscv_vse32_v_f32m8(&y[i], ny, avl);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
||||
@@ -276,27 +397,112 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
||||
|
||||
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr;
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
const int np= (n & ~(ggml_f16_step - 1));
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
||||
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
|
||||
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
|
||||
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
|
||||
|
||||
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
|
||||
|
||||
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
|
||||
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
|
||||
|
||||
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
|
||||
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
|
||||
|
||||
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
|
||||
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
|
||||
|
||||
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
|
||||
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
|
||||
|
||||
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
|
||||
}
|
||||
}
|
||||
const int np2 = (n & ~(ggml_f16_epr - 1));
|
||||
for (int k = np; k < np2; k += ggml_f16_epr) {
|
||||
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
|
||||
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
|
||||
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
GGML_F16x_VEC_STORE(y + k, ry, 0);
|
||||
}
|
||||
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b16(np2, n);
|
||||
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
|
||||
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
|
||||
hy = svmad_f16_x(pg, hx, vx, hy);
|
||||
svst1_f16(pg, (__fp16 *)(y + np2), hy);
|
||||
}
|
||||
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV impl
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
@@ -324,6 +530,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
|
||||
y[i] += x[k][i]*v[k][0];
|
||||
}
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
for (int i = 0, avl; i < n; i += avl) {
|
||||
avl = __riscv_vsetvl_e32m8(n - i);
|
||||
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
|
||||
for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {
|
||||
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
|
||||
ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
|
||||
}
|
||||
__riscv_vse32_v_f32m8(&y[i], ay, avl);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
||||
@@ -375,6 +591,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = x[i]*s + b;
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
for (int i = 0, avl; i < n; i += avl) {
|
||||
avl = __riscv_vsetvl_e32m8(n - i);
|
||||
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
|
||||
vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
|
||||
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
|
||||
__riscv_vse32_v_f32m8(&y[i], ny, avl);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
||||
@@ -436,6 +660,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
ay1 = svmul_f32_m(pg, ay1, vx);
|
||||
svst1_f32(pg, y + np, ay1);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
for (int i = 0, avl; i < n; i += avl) {
|
||||
avl = __riscv_vsetvl_e32m8(n - i);
|
||||
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
|
||||
vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
|
||||
__riscv_vse32_v_f32m8(&y[i], ny, avl);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
||||
@@ -467,25 +698,59 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
|
||||
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const int sve_register_length = svcntb() * 8;
|
||||
const int ggml_f16_epr = sve_register_length / 16;
|
||||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t ay1, ay2;
|
||||
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
|
||||
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
|
||||
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
|
||||
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
|
||||
}
|
||||
}
|
||||
// leftovers
|
||||
// maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
|
||||
if (np < n) {
|
||||
svbool_t pg = svwhilelt_b16(np, n);
|
||||
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
|
||||
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV impl
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
@@ -737,7 +1002,39 @@ https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/sr
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__aarch64__)
|
||||
#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
|
||||
|
||||
inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
|
||||
const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f);
|
||||
const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f);
|
||||
const svfloat32_t n = svsub_f32_x(pg, z, r);
|
||||
const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f);
|
||||
const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23);
|
||||
const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1))));
|
||||
const svbool_t c = svacgt_n_f32(pg, n, 126);
|
||||
const svfloat32_t u = svmul_f32_x(pg, b, b);
|
||||
const svfloat32_t j = svmla_f32_x(pg,
|
||||
svmul_n_f32_x(pg, b, 0x1.ffffecp-1f),
|
||||
svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b),
|
||||
svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u);
|
||||
const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000);
|
||||
const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000));
|
||||
const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d));
|
||||
return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1),
|
||||
svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
|
||||
}
|
||||
|
||||
// computes silu x/(1+exp(-x)) in single precision vector
|
||||
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
|
||||
const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
|
||||
const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
|
||||
const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
|
||||
const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
|
||||
const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
|
||||
return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
|
||||
}
|
||||
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
@@ -928,7 +1225,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
|
||||
return _mm_div_ps(x, one_plus_exp_neg_x);
|
||||
}
|
||||
|
||||
#endif // __ARM_NEON / __AVX2__ / __SSE2__
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
|
||||
// adapted from arm limited optimized routine
|
||||
// the maximum error is 1.45358 plus 0.5 ulps
|
||||
// numbers above 88.38 will flush to infinity
|
||||
// numbers beneath -103.97 will flush to zero
|
||||
inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
|
||||
const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
|
||||
#ifdef __riscv_xtheadvector
|
||||
// workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
|
||||
vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
|
||||
z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
|
||||
#else
|
||||
const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
|
||||
#endif
|
||||
const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
|
||||
const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
|
||||
0x1.7f7d1cp-20f, n, vl);
|
||||
const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
|
||||
const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
|
||||
const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
|
||||
const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
|
||||
const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
|
||||
__riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
|
||||
__riscv_vfmacc_vv_f32m2(
|
||||
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
|
||||
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
|
||||
u, vl), u, vl);
|
||||
if (!__riscv_vcpop_m_b16(c, vl))
|
||||
return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
|
||||
const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
|
||||
const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
|
||||
const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
|
||||
const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
|
||||
const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
|
||||
__riscv_vfmacc_vv_f32m2(k, k, j, vl),
|
||||
__riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
|
||||
c, vl);
|
||||
return __riscv_vmerge_vvm_f32m2(
|
||||
r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
|
||||
__riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
|
||||
vl);
|
||||
}
|
||||
|
||||
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
|
||||
|
||||
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
|
||||
@@ -94,7 +94,11 @@ if (CUDAToolkit_FOUND)
|
||||
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
|
||||
else ()
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
else()
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
|
||||
|
||||
+275
-170
@@ -1,5 +1,6 @@
|
||||
#include "binbcast.cuh"
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
|
||||
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
||||
return b;
|
||||
@@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
|
||||
return a / b;
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
|
||||
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
|
||||
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13) {
|
||||
const int ne0, const int ne1, const int ne2, const int ne3,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
/*int s0, */ const int s1, const int s2, const int s3,
|
||||
/*int s00,*/ const int s01, const int s02, const int s03,
|
||||
/*int s10,*/ const int s11, const int s12, const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
|
||||
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
|
||||
@@ -46,24 +50,31 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
}
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13) {
|
||||
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
|
||||
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
const int ne0, const int ne1, const int ne2,const int ne3,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
/*int s0, */ const int s1, const int s2, const int s3,
|
||||
/*int s00,*/ const int s01, const int s02, const int s03,
|
||||
/*int s10,*/ const int s11, const int s12, const int s13,
|
||||
src1_ptrs ... src1s) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const int i3 = i/(ne2*ne1*ne0);
|
||||
@@ -83,12 +94,190 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
}
|
||||
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
|
||||
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
|
||||
cudaStream_t stream, std::index_sequence<I...>) {
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
int nr0 = ne10 / ne0;
|
||||
int nr1 = ne11 / ne1;
|
||||
int nr2 = ne12 / ne2;
|
||||
int nr3 = ne13 / ne3;
|
||||
|
||||
int nr[4] = { nr0, nr1, nr2, nr3 };
|
||||
|
||||
int64_t cne[] = { ne0, ne1, ne2, ne3 };
|
||||
int64_t cne0[] = { ne00, ne01, ne02, ne03 };
|
||||
int64_t cne1[] = { ne10, ne11, ne12, ne13 };
|
||||
|
||||
size_t cnb[] = { nb0, nb1, nb2, nb3 };
|
||||
size_t cnb0[] = { nb00, nb01, nb02, nb03 };
|
||||
size_t cnb1[] = { nb10, nb11, nb12, nb13 };
|
||||
|
||||
auto collapse = [](int64_t cne[]) {
|
||||
cne[0] *= cne[1];
|
||||
cne[1] = cne[2];
|
||||
cne[2] = cne[3];
|
||||
cne[3] = 1;
|
||||
};
|
||||
|
||||
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
|
||||
cnb[1] *= cne[1];
|
||||
cnb[2] *= cne[2];
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
}
|
||||
if (i > 0) {
|
||||
collapse_nb(cnb, cne);
|
||||
collapse_nb(cnb0, cne0);
|
||||
collapse_nb(cnb1, cne1);
|
||||
collapse(cne);
|
||||
collapse(cne0);
|
||||
collapse(cne1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
int64_t ne0 = cne[0];
|
||||
int64_t ne1 = cne[1];
|
||||
int64_t ne2 = cne[2];
|
||||
int64_t ne3 = cne[3];
|
||||
|
||||
//int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
|
||||
//int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
|
||||
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
|
||||
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
|
||||
|
||||
int64_t ne10 = cne1[0];
|
||||
int64_t ne11 = cne1[1];
|
||||
int64_t ne12 = cne1[2];
|
||||
int64_t ne13 = cne1[3];
|
||||
|
||||
size_t nb0 = cnb[0];
|
||||
size_t nb1 = cnb[1];
|
||||
size_t nb2 = cnb[2];
|
||||
size_t nb3 = cnb[3];
|
||||
|
||||
size_t nb00 = cnb0[0];
|
||||
size_t nb01 = cnb0[1];
|
||||
size_t nb02 = cnb0[2];
|
||||
size_t nb03 = cnb0[3];
|
||||
|
||||
size_t nb10 = cnb1[0];
|
||||
size_t nb11 = cnb1[1];
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s00 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
|
||||
|
||||
dim3 block_dims;
|
||||
block_dims.x = std::min<unsigned int>(hne0, block_size);
|
||||
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
|
||||
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
|
||||
|
||||
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
|
||||
(ne1 + block_dims.y - 1) / block_dims.y,
|
||||
(ne2 * ne3 + block_dims.z - 1) / block_dims.z);
|
||||
|
||||
if (block_nums.z > 65535) {
|
||||
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
|
||||
ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12,s13,
|
||||
(const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
|
||||
ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12,s13);
|
||||
}
|
||||
} else {
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
|
||||
ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12,s13,
|
||||
(const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
|
||||
ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12,s13);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -120,160 +309,14 @@ static __global__ void k_repeat_back(
|
||||
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
template <float (*bin_op)(const float, const float), int n_fuse = 1>
|
||||
struct bin_bcast_cuda {
|
||||
template<typename src0_t, typename src1_t, typename dst_t>
|
||||
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
|
||||
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
|
||||
cudaStream_t stream) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
int nr0 = ne10/ne0;
|
||||
int nr1 = ne11/ne1;
|
||||
int nr2 = ne12/ne2;
|
||||
int nr3 = ne13/ne3;
|
||||
|
||||
int nr[4] = { nr0, nr1, nr2, nr3 };
|
||||
|
||||
// collapse dimensions until first broadcast dimension
|
||||
int64_t cne[] = {ne0, ne1, ne2, ne3};
|
||||
int64_t cne0[] = {ne00, ne01, ne02, ne03};
|
||||
int64_t cne1[] = {ne10, ne11, ne12, ne13};
|
||||
|
||||
size_t cnb[] = {nb0, nb1, nb2, nb3};
|
||||
size_t cnb0[] = {nb00, nb01, nb02, nb03};
|
||||
size_t cnb1[] = {nb10, nb11, nb12, nb13};
|
||||
|
||||
auto collapse = [](int64_t cne[]) {
|
||||
cne[0] *= cne[1];
|
||||
cne[1] = cne[2];
|
||||
cne[2] = cne[3];
|
||||
cne[3] = 1;
|
||||
};
|
||||
|
||||
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
|
||||
cnb[1] *= cne[1];
|
||||
cnb[2] *= cne[2];
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
}
|
||||
if (i > 0) {
|
||||
collapse_nb(cnb, cne);
|
||||
collapse_nb(cnb0, cne0);
|
||||
collapse_nb(cnb1, cne1);
|
||||
collapse(cne);
|
||||
collapse(cne0);
|
||||
collapse(cne1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
int64_t ne0 = cne[0];
|
||||
int64_t ne1 = cne[1];
|
||||
int64_t ne2 = cne[2];
|
||||
int64_t ne3 = cne[3];
|
||||
|
||||
//int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
|
||||
//int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
|
||||
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
|
||||
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
|
||||
|
||||
int64_t ne10 = cne1[0];
|
||||
int64_t ne11 = cne1[1];
|
||||
int64_t ne12 = cne1[2];
|
||||
int64_t ne13 = cne1[3];
|
||||
|
||||
size_t nb0 = cnb[0];
|
||||
size_t nb1 = cnb[1];
|
||||
size_t nb2 = cnb[2];
|
||||
size_t nb3 = cnb[3];
|
||||
|
||||
size_t nb00 = cnb0[0];
|
||||
size_t nb01 = cnb0[1];
|
||||
size_t nb02 = cnb0[2];
|
||||
size_t nb03 = cnb0[3];
|
||||
|
||||
size_t nb10 = cnb1[0];
|
||||
size_t nb11 = cnb1[1];
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s00 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0/2LL, 1LL);
|
||||
|
||||
dim3 block_dims;
|
||||
block_dims.x = std::min<unsigned int>(hne0, block_size);
|
||||
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
|
||||
block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
|
||||
|
||||
dim3 block_nums(
|
||||
(hne0 + block_dims.x - 1) / block_dims.x,
|
||||
(ne1 + block_dims.y - 1) / block_dims.y,
|
||||
(ne2*ne3 + block_dims.z - 1) / block_dims.z
|
||||
);
|
||||
|
||||
if (block_nums.z > 65535) {
|
||||
// this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
|
||||
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
|
||||
k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd,
|
||||
ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00, */ s01, s02, s03,
|
||||
/* s10, */ s11, s12, s13);
|
||||
} else {
|
||||
k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd,
|
||||
ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00, */ s01, s02, s03,
|
||||
/* s10, */ s11, s12, s13);
|
||||
}
|
||||
}
|
||||
launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
|
||||
src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -312,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
|
||||
}
|
||||
|
||||
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
||||
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -331,6 +374,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
||||
template <float (*op)(const float, const float), int n_fuse>
|
||||
static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
|
||||
(const float *) src0->data, (const float *) src1->data, (float *) dst->data,
|
||||
stream, std::make_index_sequence<n_fuse>{});
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
|
||||
(const half *) src0->data, (const half *) src1->data, (half *) dst->data,
|
||||
stream, std::make_index_sequence<n_fuse>{});
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
|
||||
launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
|
||||
(const half *) src0->data, (const float *) src1->data, (half *) dst->data,
|
||||
stream, std::make_index_sequence<n_fuse>{});
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
|
||||
(const half *) src0->data, (const float *) src1->data, (float *) dst->data,
|
||||
stream, std::make_index_sequence<n_fuse>{});
|
||||
} else {
|
||||
fprintf(stderr,
|
||||
"%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
|
||||
__func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
|
||||
GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
|
||||
|
||||
switch (n_fuse) {
|
||||
case 2:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
|
||||
break;
|
||||
case 3:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
|
||||
break;
|
||||
case 4:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
|
||||
break;
|
||||
case 5:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
|
||||
break;
|
||||
case 6:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
|
||||
break;
|
||||
case 7:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
|
||||
break;
|
||||
case 8:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "Unsupported n_fuse value");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
|
||||
@@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
|
||||
|
||||
@@ -107,9 +107,9 @@ constexpr bool ggml_cuda_has_arch(const int arch) {
|
||||
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
|
||||
}
|
||||
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
|
||||
constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
|
||||
if (cur == 0) {
|
||||
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
|
||||
return -1;
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_all(int x) {
|
||||
#ifdef GGML_USE_HIP
|
||||
if (width == ggml_cuda_get_physical_warp_size()) {
|
||||
return __all_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_any(int x) {
|
||||
if (width == ggml_cuda_get_physical_warp_size()) {
|
||||
return __any_sync(0xffffffff, x);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int offset = width/2; offset > 0; offset >>= 1) {
|
||||
x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
|
||||
return __all_sync(0xffffffff, x);
|
||||
#endif // GGML_USE_HIP
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
#include "conv2d.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
struct conv_params {
|
||||
const int64_t IW, IH;
|
||||
const int64_t OW, OH;
|
||||
const int64_t KW, KH;
|
||||
const int64_t ST_X, ST_Y;
|
||||
const int64_t PD_X, PD_Y;
|
||||
const int64_t DL_X, DL_Y;
|
||||
const int64_t IC, OC;
|
||||
const int64_t B;
|
||||
const int64_t TOTAL;
|
||||
};
|
||||
|
||||
struct kernel_bounds {
|
||||
int64_t y_min, y_max;
|
||||
int64_t x_min, x_max;
|
||||
};
|
||||
|
||||
__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
|
||||
kernel_bounds bounds;
|
||||
bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
|
||||
bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
|
||||
bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
|
||||
bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
|
||||
return bounds;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
|
||||
int64_t kern_coord,
|
||||
int64_t stride,
|
||||
int64_t dilation,
|
||||
int64_t padding) {
|
||||
return out_coord * stride + kern_coord * dilation - padding;
|
||||
}
|
||||
|
||||
struct whcn_layout {
|
||||
__device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
|
||||
return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
|
||||
}
|
||||
|
||||
__device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
|
||||
return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
|
||||
}
|
||||
|
||||
__device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
|
||||
return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
|
||||
}
|
||||
|
||||
__device__ static void unpack_indices(int64_t global_idx,
|
||||
const conv_params & P,
|
||||
int64_t & n,
|
||||
int64_t & c,
|
||||
int64_t & out_y,
|
||||
int64_t & out_x) {
|
||||
out_x = global_idx % P.OW;
|
||||
out_y = (global_idx / P.OW) % P.OH;
|
||||
c = (global_idx / (P.OW * P.OH)) % P.OC;
|
||||
n = global_idx / (P.OW * P.OH * P.OC);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Layout>
|
||||
static __global__ void conv2d_kernel(const float * __restrict__ input,
|
||||
const T * __restrict__ kernel,
|
||||
float * __restrict__ output,
|
||||
const conv_params P) {
|
||||
const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (global_idx >= P.TOTAL) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t n, c_out, out_y, out_x;
|
||||
Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
|
||||
kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
|
||||
|
||||
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
|
||||
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
|
||||
|
||||
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
|
||||
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
|
||||
|
||||
const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
|
||||
const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
|
||||
acc += (input_val * ggml_cuda_cast<float>(kernel_val));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// [N, OC, OH, OW]
|
||||
output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
|
||||
const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
|
||||
conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
|
||||
}
|
||||
|
||||
static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
|
||||
conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
|
||||
}
|
||||
|
||||
static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
|
||||
conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * kernel = dst->src[0];
|
||||
const ggml_tensor * input = dst->src[1];
|
||||
float * K_D = (float *) kernel->data;
|
||||
const float * X_D = (const float *) input->data;
|
||||
float * Y_D = (float *) dst->data;
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(kernel));
|
||||
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
|
||||
|
||||
// same number of input channels
|
||||
GGML_ASSERT(input->ne[2] == kernel->ne[2]);
|
||||
|
||||
cudaStream_t st = ctx.stream();
|
||||
|
||||
const int32_t * p = (const int32_t *) dst->op_params;
|
||||
const int ST_X = p[0]; // stride_x
|
||||
const int ST_Y = p[1]; // stride_y
|
||||
const int PD_X = p[2]; // padding_x
|
||||
const int PD_Y = p[3]; // padding_y
|
||||
const int DL_X = p[4]; // dilation_x
|
||||
const int DL_Y = p[5]; // dilation_y
|
||||
|
||||
// No cwhn
|
||||
GGML_ASSERT(p[6] == false);
|
||||
|
||||
const int IW = input->ne[0]; // input_w
|
||||
const int IH = input->ne[1]; // input_h
|
||||
const int OW = dst->ne[0]; // output_w
|
||||
const int OH = dst->ne[1]; // output_h
|
||||
const int KW = kernel->ne[0]; // kernel_w
|
||||
const int KH = kernel->ne[1]; // kernel_h
|
||||
const int IC = input->ne[2]; // input_channels
|
||||
const int OC = kernel->ne[3]; // ouptut_chanles
|
||||
const int B = input->ne[3]; // n_batches
|
||||
|
||||
const int64_t total = B * OC * OH * OW;
|
||||
conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
|
||||
|
||||
if (kernel->type == GGML_TYPE_F16) {
|
||||
conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
|
||||
} else {
|
||||
conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CONV2D_BLOCK_SIZE 256
|
||||
void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -258,7 +258,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const half val = hexp(sink - kqmax[j0/nwarps]);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
|
||||
kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/conv2d.cuh"
|
||||
#include "ggml-cuda/conv2d-dw.cuh"
|
||||
#include "ggml-cuda/conv2d-transpose.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
@@ -49,6 +50,7 @@
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -203,6 +205,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
|
||||
#endif // GGML_CUDA_FORCE_CUBLAS
|
||||
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
|
||||
|
||||
std::vector<std::pair<int, std::string>> turing_devices_without_mma;
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
||||
@@ -260,7 +264,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
std::string device_name(prop.name);
|
||||
if (device_name == "NVIDIA GeForce MX450") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
} else if (device_name == "NVIDIA GeForce MX550") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
|
||||
if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
|
||||
GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
|
||||
for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
|
||||
GGML_LOG_INFO(
|
||||
" Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
|
||||
}
|
||||
GGML_LOG_INFO(
|
||||
"Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
|
||||
}
|
||||
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
@@ -2352,6 +2374,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_PAD:
|
||||
ggml_cuda_op_pad(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
ggml_cuda_op_pad_reflect_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_cuda_op_arange(ctx, dst);
|
||||
break;
|
||||
@@ -2427,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_IM2COL:
|
||||
ggml_cuda_op_im2col(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_2D:
|
||||
ggml_cuda_op_conv2d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_cuda_op_conv2d_dw(ctx, dst);
|
||||
break;
|
||||
@@ -2793,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
||||
if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
||||
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
|
||||
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
|
||||
const ggml_tensor *add = nullptr;
|
||||
|
||||
if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
|
||||
add = cgraph->nodes[node_idx+2];
|
||||
}
|
||||
|
||||
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
||||
@@ -2807,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (add && (add->src[0]->type != GGML_TYPE_F32 ||
|
||||
add->src[1]->type != GGML_TYPE_F32 ||
|
||||
add->type != GGML_TYPE_F32) ) {
|
||||
return false;
|
||||
}
|
||||
|
||||
//if rms norm is the B operand, then we don't handle broadcast
|
||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
||||
return false;
|
||||
@@ -2817,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -2863,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
|
||||
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
|
||||
if (!disable_fusion) {
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
|
||||
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
std::fill(ops, ops + 8, GGML_OP_ADD);
|
||||
|
||||
for (; n_fuse <= 6; ++n_fuse){
|
||||
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
|
||||
break;
|
||||
}
|
||||
if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
n_fuse++;
|
||||
|
||||
if (n_fuse > 1) {
|
||||
for (int j = 0; j < n_fuse - 1; ++j) {
|
||||
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
}
|
||||
cgraph->nodes[i + n_fuse - 1]->data = node->data;
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
|
||||
i += n_fuse - 1;
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
|
||||
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
|
||||
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
continue;
|
||||
@@ -3082,7 +3164,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
|
||||
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
@@ -3477,19 +3559,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
|
||||
}
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
|
||||
+177
-47
@@ -3,6 +3,140 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
||||
struct mmq_ids_helper_store {
|
||||
uint32_t data;
|
||||
|
||||
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
||||
}
|
||||
|
||||
__device__ uint32_t it() const {
|
||||
return data & 0x003FFFFF;
|
||||
}
|
||||
|
||||
__device__ uint32_t iex_used() const {
|
||||
return data >> 22;
|
||||
}
|
||||
};
|
||||
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
|
||||
|
||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
||||
// ids_dst describes the same mapping but for the dst tensor.
|
||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mmq_ids_helper[];
|
||||
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||
const int it = it0 + threadIdx.x / neu_padded;
|
||||
|
||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||
ids[it*si1 + iex] : INT_MAX;
|
||||
const int iex_used = expert_used == expert ? iex : -1;
|
||||
nex_prev += expert_used < expert;
|
||||
|
||||
// Whether the threads at this token position have used the expert:
|
||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||
|
||||
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||
int it_compact_add_lower = 0;
|
||||
#pragma unroll
|
||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||
if (threadIdx.x >= offset) {
|
||||
it_compact_add_lower += tmp;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||
}
|
||||
}
|
||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||
|
||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||
const mmq_ids_helper_store store_it = store[itc];
|
||||
const int it = store_it.it();
|
||||
const int iex_used = store_it.iex_used();
|
||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||
}
|
||||
|
||||
if (threadIdx.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[expert] = nex_prev;
|
||||
|
||||
if (expert < gridDim.x - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||
}
|
||||
|
||||
template <int n_expert_used_template>
|
||||
static void launch_mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
|
||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
|
||||
|
||||
const dim3 num_blocks(n_experts, 1, 1);
|
||||
const dim3 block_size(warp_size, 1, 1);
|
||||
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
|
||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||
switch (args.type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
|
||||
ne00, ne01, ne1, s01, ne11, s1,
|
||||
ne02, ne12, s02, s12, s2,
|
||||
ne03, ne13, s03, s13, s3,
|
||||
use_stream_k};
|
||||
use_stream_k, ne1};
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
return;
|
||||
}
|
||||
@@ -148,54 +282,50 @@ void ggml_cuda_mul_mat_q(
|
||||
|
||||
const int64_t n_expert_used = ids->ne[0];
|
||||
const int64_t ne_get_rows = ne12 * n_expert_used;
|
||||
GGML_ASSERT(ne1 == n_expert_used);
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
std::vector<int32_t> ids_src1_host;
|
||||
ids_src1_host.reserve(ne_get_rows);
|
||||
std::vector<int32_t> ids_dst_host;
|
||||
ids_dst_host.reserve(ne_get_rows);
|
||||
std::vector<int32_t> tokens_per_expert_host(ne02);
|
||||
std::vector<int32_t> expert_bounds_host(ne02 + 1);
|
||||
ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
|
||||
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
|
||||
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
|
||||
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
{
|
||||
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
|
||||
const int si1 = ids->nb[1] / ggml_element_size(ids);
|
||||
const int sis1 = nb12 / nb11;
|
||||
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
|
||||
for (int64_t iex = 0; iex < n_expert_used; ++iex) {
|
||||
const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
|
||||
assert(expert_to_use >= 0 && expert_to_use < ne02);
|
||||
if (expert_to_use == i02) {
|
||||
ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
|
||||
ids_dst_host.push_back(i12*ne1 + iex);
|
||||
tokens_per_expert_host[i02]++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
switch (n_expert_used) {
|
||||
case 2:
|
||||
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
default:
|
||||
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
int32_t cumsum = 0;
|
||||
for (int64_t i = 0; i < ne02; ++i) {
|
||||
expert_bounds_host[i] = cumsum;
|
||||
cumsum += tokens_per_expert_host[i];
|
||||
}
|
||||
expert_bounds_host[ne02] = cumsum;
|
||||
|
||||
std::vector<int32_t> ids_buf_host;
|
||||
ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
|
||||
ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
|
||||
ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
|
||||
ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
|
||||
ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
|
||||
CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
const int32_t * ids_src1_dev = ids_buf_dev.ptr;
|
||||
const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
|
||||
const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
|
||||
|
||||
const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
|
||||
get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
|
||||
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
|
||||
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s12 = src1->nb[2] / ts_src1;
|
||||
const int64_t s13 = src1->nb[2] / ts_src1;
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
|
||||
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
|
||||
ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
|
||||
|
||||
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
|
||||
const mmq_args args = {
|
||||
src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
|
||||
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
|
||||
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
|
||||
ne02, ne02, s02, s12, s2,
|
||||
ne03, ne13, s03, s13, s3,
|
||||
use_stream_k};
|
||||
use_stream_k, ne12};
|
||||
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
}
|
||||
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
|
||||
1, 1, 0, 0, 0,
|
||||
1, 1, 0, 0, 0,
|
||||
use_stream_k};
|
||||
use_stream_k, src1_ncols};
|
||||
|
||||
ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
|
||||
|
||||
|
||||
+21
-13
@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
|
||||
// Skip unused template specializations for faster compilation:
|
||||
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
||||
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
|
||||
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
||||
|
||||
// Initialize the ids for writing back data with just the index.
|
||||
@@ -3376,7 +3377,8 @@ template <ggml_type type, int mmq_x, bool need_check>
|
||||
static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
||||
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
||||
|
||||
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
|
||||
|
||||
const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
|
||||
const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -3528,7 +3530,7 @@ struct mmq_args {
|
||||
int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
|
||||
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
|
||||
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
|
||||
bool use_stream_k;
|
||||
bool use_stream_k; int64_t ncols_max;
|
||||
};
|
||||
|
||||
template<ggml_type type>
|
||||
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
|
||||
|
||||
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
|
||||
const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
|
||||
const int ntzw = args.nchannels_y * args.nsamples_y;
|
||||
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
|
||||
|
||||
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
|
||||
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
|
||||
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
|
||||
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
||||
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
|
||||
if (!fixup_needed) {
|
||||
return;
|
||||
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
|
||||
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
|
||||
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
|
||||
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
|
||||
args.ncols_max);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
continue;
|
||||
}
|
||||
|
||||
const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
|
||||
const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
|
||||
|
||||
if (ntiles_x < ntiles_x_best) {
|
||||
mmq_x_best = mmq_x;
|
||||
|
||||
+192
-19
@@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size, bool do_multiply = false>
|
||||
static __global__ void rms_norm_f32(
|
||||
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
|
||||
const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
|
||||
const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
|
||||
template <int block_size, bool do_multiply = false, bool do_add = false>
|
||||
static __global__ void rms_norm_f32(const float * x, float * dst,
|
||||
const int ncols,
|
||||
const int64_t stride_row,
|
||||
const int64_t stride_channel,
|
||||
const int64_t stride_sample,
|
||||
const float eps,
|
||||
const float * mul = nullptr,
|
||||
const int64_t mul_stride_row = 0,
|
||||
const int64_t mul_stride_channel = 0,
|
||||
const int64_t mul_stride_sample = 0,
|
||||
const int mul_ncols = 0,
|
||||
const int mul_nrows = 0,
|
||||
const int mul_nchannels = 0,
|
||||
const int mul_nsamples = 0,
|
||||
const float * add = nullptr,
|
||||
const int64_t add_stride_row = 0,
|
||||
const int64_t add_stride_channel = 0,
|
||||
const int64_t add_stride_sample = 0,
|
||||
const int add_ncols = 0,
|
||||
const int add_nrows = 0,
|
||||
const int add_nchannels = 0,
|
||||
const int add_nsamples = 0) {
|
||||
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
|
||||
@@ -118,6 +136,8 @@ static __global__ void rms_norm_f32(
|
||||
const int sample = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
@@ -128,6 +148,13 @@ static __global__ void rms_norm_f32(
|
||||
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
|
||||
}
|
||||
|
||||
if constexpr (do_add) {
|
||||
const int add_row = row % add_nrows;
|
||||
const int add_channel = channel % add_nchannels;
|
||||
const int add_sample = sample % add_nsamples;
|
||||
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
|
||||
}
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
@@ -154,7 +181,11 @@ static __global__ void rms_norm_f32(
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
if constexpr (do_multiply) {
|
||||
if constexpr (do_multiply && do_add) {
|
||||
const int mul_col = col % mul_ncols;
|
||||
const int add_col = col % add_ncols;
|
||||
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
|
||||
} else if constexpr (do_multiply) {
|
||||
const int mul_col = col % mul_ncols;
|
||||
dst[col] = scale * x[col] * mul[mul_col];
|
||||
} else {
|
||||
@@ -331,23 +362,70 @@ static void rms_norm_f32_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_mul_f32_cuda(
|
||||
const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
|
||||
const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
|
||||
const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
|
||||
const float eps, cudaStream_t stream) {
|
||||
static void rms_norm_mul_f32_cuda(const float * x,
|
||||
const float * mul,
|
||||
const float * add,
|
||||
float * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
const int nchannels,
|
||||
const int nsamples,
|
||||
const int64_t stride_row,
|
||||
const int64_t stride_channel,
|
||||
const int64_t stride_sample,
|
||||
const int64_t mul_stride_row,
|
||||
const int64_t mul_stride_channel,
|
||||
const int64_t mul_stride_sample,
|
||||
const int mul_ncols,
|
||||
const int mul_nrows,
|
||||
const int mul_nchannels,
|
||||
const int mul_nsamples,
|
||||
const int64_t add_stride_row,
|
||||
const int64_t add_stride_channel,
|
||||
const int64_t add_stride_sample,
|
||||
const int add_ncols,
|
||||
const int add_nrows,
|
||||
const int add_nchannels,
|
||||
const int add_nsamples,
|
||||
const float eps,
|
||||
cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (mul == nullptr) {
|
||||
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
|
||||
return;
|
||||
}
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
|
||||
if (add == nullptr) {
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
||||
ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
||||
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
||||
ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
||||
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
|
||||
}
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
||||
ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
||||
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
|
||||
add, add_stride_row, add_stride_channel, add_stride_sample,
|
||||
add_ncols, add_nrows, add_nchannels, add_nsamples);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
|
||||
ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
||||
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
|
||||
add, add_stride_row, add_stride_channel, add_stride_sample,
|
||||
add_ncols, add_nrows, add_nchannels, add_nsamples);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
const int mul_nchannels = mul_src->ne[2];
|
||||
const int mul_nsamples = mul_src->ne[3];
|
||||
|
||||
rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
|
||||
rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
/*s00*/ s01, s02, s03,
|
||||
/*mul_s00*/ mul_s01, mul_s02, mul_s03,
|
||||
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
|
||||
/*add_s00*/ 0, 0, 0,
|
||||
0, 0, 0, 0,
|
||||
eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * dst,
|
||||
ggml_tensor * mul_tensor,
|
||||
ggml_tensor * add_tensor) {
|
||||
const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
|
||||
float eps = 0.0f;
|
||||
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const float * src0_d = (const float *) rms_norm_src->data;
|
||||
const float * mul_d = nullptr;
|
||||
const ggml_tensor * mul_src = nullptr;
|
||||
|
||||
if (mul_tensor->src[0] == dst) {
|
||||
mul_d = (float *) mul_tensor->src[1]->data;
|
||||
mul_src = mul_tensor->src[1];
|
||||
} else if (mul_tensor->src[1] == dst) {
|
||||
mul_d = (float *) mul_tensor->src[0]->data;
|
||||
mul_src = mul_tensor->src[0];
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
const float * add_d = nullptr;
|
||||
const ggml_tensor * add_src = nullptr;
|
||||
|
||||
if (add_tensor->src[0] == mul_tensor) {
|
||||
add_d = (float *) add_tensor->src[1]->data;
|
||||
add_src = add_tensor->src[1];
|
||||
} else if (add_tensor->src[1] == mul_tensor) {
|
||||
add_d = (float *) add_tensor->src[0]->data;
|
||||
add_src = add_tensor->src[0];
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
float * dst_d = (float *) add_tensor->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
const int64_t ne00 = rms_norm_src->ne[0];
|
||||
const int64_t ne01 = rms_norm_src->ne[1];
|
||||
const int64_t ne02 = rms_norm_src->ne[2];
|
||||
const int64_t ne03 = rms_norm_src->ne[3];
|
||||
|
||||
const size_t ts0 = ggml_type_size(rms_norm_src->type);
|
||||
GGML_ASSERT(rms_norm_src->nb[0] == ts0);
|
||||
const int64_t s01 = rms_norm_src->nb[1] / ts0;
|
||||
const int64_t s02 = rms_norm_src->nb[2] / ts0;
|
||||
const int64_t s03 = rms_norm_src->nb[3] / ts0;
|
||||
|
||||
const size_t ts_mul = ggml_type_size(mul_src->type);
|
||||
GGML_ASSERT(mul_src->nb[0] == ts_mul);
|
||||
const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
|
||||
const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
|
||||
const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
|
||||
|
||||
const int mul_ncols = mul_src->ne[0];
|
||||
const int mul_nrows = mul_src->ne[1];
|
||||
const int mul_nchannels = mul_src->ne[2];
|
||||
const int mul_nsamples = mul_src->ne[3];
|
||||
|
||||
const size_t ts_add = ggml_type_size(add_src->type);
|
||||
GGML_ASSERT(add_src->nb[0] == ts_add);
|
||||
const int64_t add_s01 = add_src->nb[1] / ts_add;
|
||||
const int64_t add_s02 = add_src->nb[2] / ts_add;
|
||||
const int64_t add_s03 = add_src->nb[3] / ts_add;
|
||||
|
||||
const int add_ncols = add_src->ne[0];
|
||||
const int add_nrows = add_src->ne[1];
|
||||
const int add_nchannels = add_src->ne[2];
|
||||
const int add_nsamples = add_src->ne[3];
|
||||
|
||||
rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
|
||||
ne00,ne01, ne02, ne03,
|
||||
/*s00*/ s01, s02, s03,
|
||||
/*mul_s00*/ mul_s01, mul_s02, mul_s03,
|
||||
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
|
||||
/*add_s00*/ add_s01, add_s02, add_s03,
|
||||
add_ncols, add_nrows, add_nchannels, add_nsamples,
|
||||
eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
|
||||
|
||||
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * dst,
|
||||
ggml_tensor * mul_tensor,
|
||||
ggml_tensor * add_tensor);
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
#include "pad_reflect_1d.cuh"
|
||||
|
||||
static __global__ void pad_reflect_1d_kernel_f32(
|
||||
const void * __restrict__ src0,
|
||||
void * __restrict__ dst,
|
||||
const int64_t ne0,
|
||||
const int64_t ne00,
|
||||
const int64_t ne01,
|
||||
const int64_t ne02,
|
||||
const int64_t ne03,
|
||||
const int64_t nb00,
|
||||
const int64_t nb01,
|
||||
const int64_t nb02,
|
||||
const int64_t nb03,
|
||||
const int64_t nb0,
|
||||
const int64_t nb1,
|
||||
const int64_t nb2,
|
||||
const int64_t nb3,
|
||||
const int p0,
|
||||
const int p1) {
|
||||
|
||||
const int64_t i3 = blockIdx.z;
|
||||
const int64_t i2 = blockIdx.y;
|
||||
const int64_t i1 = blockIdx.x;
|
||||
|
||||
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
|
||||
return;
|
||||
}
|
||||
|
||||
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
|
||||
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
|
||||
|
||||
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
|
||||
float value;
|
||||
|
||||
if (i0 < p0) {
|
||||
// Left padding - reflect
|
||||
value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
|
||||
} else if (i0 < ne0 - p1) {
|
||||
// Middle - copy
|
||||
value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
|
||||
} else {
|
||||
// Right padding - reflect
|
||||
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
|
||||
value = *(const float *)(src0_ptr + src_idx * nb00);
|
||||
}
|
||||
|
||||
*(float *)(dst_ptr + i0 * nb0) = value;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int32_t * opts = (const int32_t *) dst->op_params;
|
||||
const int p0 = opts[0];
|
||||
const int p1 = opts[1];
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
GGML_ASSERT(ne0 == ne00 + p0 + p1);
|
||||
|
||||
const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
|
||||
const dim3 grid_dims(ne01, ne02, ne03);
|
||||
|
||||
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src0->data, dst->data,
|
||||
ne0, ne00, ne01, ne02, ne03,
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
|
||||
p0, p1
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1)
|
||||
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
|
||||
const int seq_idx = blockIdx.y;
|
||||
|
||||
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
|
||||
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
|
||||
|
||||
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
|
||||
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
|
||||
|
||||
@@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
|
||||
return ((const int *) x)[i32]; // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
// q4 contains 8 indices with 4 bit each.
|
||||
// This function selects those bytes from table that are at those indices and returns them as int2.
|
||||
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
|
||||
#if defined(GGML_USE_HIP)
|
||||
// Load the 16-byte table into four 32-bit unsigned integers.
|
||||
const uint32_t *values = (const uint32_t *)table;
|
||||
|
||||
const uint32_t q_even = q4;
|
||||
const uint32_t q_odd = (q4 >> 4);
|
||||
|
||||
// Perform lookups in the lower half of the table (indices 0-7).
|
||||
uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
|
||||
uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
|
||||
|
||||
// Perform lookups in the upper half of the table (indices 8-15).
|
||||
uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
|
||||
uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
|
||||
|
||||
// Select between the low and high results based on the MSB of each index nibble.
|
||||
uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
|
||||
uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
|
||||
uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
|
||||
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
|
||||
|
||||
return make_int2(res_x, res_y);
|
||||
#elif !defined(GGML_USE_MUSA)
|
||||
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
|
||||
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
|
||||
const uint32_t * table32 = (const uint32_t *) table;
|
||||
|
||||
// __byte_perm selects bytes based on the lower 16 bits in its third argument.
|
||||
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
|
||||
// To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
|
||||
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
|
||||
uint32_t tmp[2];
|
||||
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 2; ++i) {
|
||||
const uint32_t shift = 16 * i;
|
||||
|
||||
const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
|
||||
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
|
||||
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
|
||||
}
|
||||
|
||||
// tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.
|
||||
// However, for the result we need ints with all even/odd 4 bit indices in q4.
|
||||
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
|
||||
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
|
||||
#else
|
||||
// Generic implementation.
|
||||
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
||||
const char4 val0_8 = make_char4(
|
||||
@@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
|
||||
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
#endif
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
|
||||
Vendored
+3
@@ -22,7 +22,10 @@
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define __all_sync(mask, var) __all(var)
|
||||
#define __any_sync(mask, var) __any(var)
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
|
||||
@@ -249,6 +249,7 @@ typedef struct {
|
||||
uint64_t nb33;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
@@ -257,6 +258,11 @@ typedef struct {
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t nrows;
|
||||
int32_t ne20;
|
||||
} ggml_metal_kargs_flash_attn_ext_reduce;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
@@ -320,40 +326,31 @@ typedef struct {
|
||||
} ggml_metal_kargs_mul_mv_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne02;
|
||||
int32_t ne10;
|
||||
int32_t ne11; // n_expert_used (bcast)
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t neh11; // n_tokens
|
||||
uint64_t nbh11;
|
||||
int32_t ne21; // n_tokens
|
||||
int32_t ne20; // n_expert_used
|
||||
uint64_t nb21;
|
||||
} ggml_metal_kargs_mul_mm_id_map0;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne20; // n_expert_used
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
uint64_t nbh1;
|
||||
uint64_t nbh2;
|
||||
int32_t ne0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_mul_mm_id_map1;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t neh12;
|
||||
uint64_t nbh10;
|
||||
uint64_t nbh11;
|
||||
uint64_t nbh12;
|
||||
uint64_t nbh13;
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
int32_t ne11;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne20;
|
||||
int32_t ne21;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mm_id;
|
||||
|
||||
+216
-132
@@ -93,35 +93,37 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
||||
if (ctx->mtl_device == nil) {
|
||||
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
||||
|
||||
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
if (ctx->mtl_device) {
|
||||
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
|
||||
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
|
||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||
#endif
|
||||
|
||||
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
#else
|
||||
ctx->use_bfloat = false;
|
||||
ctx->use_bfloat = false;
|
||||
#endif
|
||||
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
||||
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
||||
ctx->debug_fusion = val ? atoi(val) : 0;
|
||||
{
|
||||
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
||||
ctx->debug_fusion = val ? atoi(val) : 0;
|
||||
}
|
||||
|
||||
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
||||
|
||||
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
||||
|
||||
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
||||
}
|
||||
|
||||
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
||||
|
||||
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
||||
|
||||
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
||||
}
|
||||
|
||||
ctx->mtl_device_ref_count++;
|
||||
@@ -289,6 +291,10 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
||||
@@ -396,8 +402,12 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
|
||||
@@ -443,6 +453,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||
@@ -452,6 +463,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
||||
@@ -461,6 +473,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
||||
@@ -470,6 +483,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
||||
@@ -479,6 +493,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
||||
@@ -488,6 +503,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
||||
@@ -497,6 +513,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
||||
@@ -555,6 +572,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
|
||||
GGML_METAL_KERNEL_TYPE_SET_I32,
|
||||
GGML_METAL_KERNEL_TYPE_SET_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
@@ -1304,6 +1322,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
||||
@@ -1412,8 +1434,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
|
||||
@@ -1459,6 +1485,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
|
||||
@@ -1468,6 +1495,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
||||
@@ -1477,6 +1505,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
||||
@@ -1486,6 +1515,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
||||
@@ -1495,6 +1525,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
||||
@@ -1504,6 +1535,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
||||
@@ -1513,6 +1545,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
||||
@@ -1571,6 +1604,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
@@ -1846,7 +1880,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_ROPE:
|
||||
return true;
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
|
||||
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
|
||||
case GGML_OP_POOL_1D:
|
||||
return false;
|
||||
case GGML_OP_UPSCALE:
|
||||
@@ -1861,9 +1895,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_ARANGE:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
if (op->src[0]->ne[0] == 32) {
|
||||
// head size == 32 (e.g. bert-bge-small)
|
||||
// TODO: not sure if it is worth adding kernels for this size
|
||||
// for new head sizes, add checks here
|
||||
if (op->src[0]->ne[0] != 40 &&
|
||||
op->src[0]->ne[0] != 64 &&
|
||||
op->src[0]->ne[0] != 80 &&
|
||||
op->src[0]->ne[0] != 96 &&
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
op->src[0]->ne[0] != 256) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
@@ -3347,15 +3387,16 @@ static int ggml_metal_encode_node(
|
||||
|
||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||
// to the matrix-vector kernel
|
||||
const int ne11_mm_min = 4;
|
||||
const int ne11_mm_min = 8;
|
||||
|
||||
// first try to use small-batch mat-mv kernels
|
||||
// these should be efficient for BS [2, ~8]
|
||||
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
|
||||
if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
|
||||
(
|
||||
(
|
||||
(
|
||||
src0t == GGML_TYPE_F16 || // TODO: helper function
|
||||
src0t == GGML_TYPE_F32 || // TODO: helper function
|
||||
src0t == GGML_TYPE_F16 ||
|
||||
src0t == GGML_TYPE_Q4_0 ||
|
||||
src0t == GGML_TYPE_Q4_1 ||
|
||||
src0t == GGML_TYPE_Q5_0 ||
|
||||
@@ -3383,7 +3424,17 @@ static int ggml_metal_encode_node(
|
||||
// values and there can be some tail effects when nsg is high. need to confirm this
|
||||
//
|
||||
const int nsg = 2; // num simdgroups per threadgroup
|
||||
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
|
||||
|
||||
// num threads along row per simdgroup
|
||||
int nxpsg = 0;
|
||||
if (ne00 % 256 == 0 && ne11 < 3) {
|
||||
nxpsg = 16;
|
||||
} else if (ne00 % 128 == 0) {
|
||||
nxpsg = 8;
|
||||
} else {
|
||||
nxpsg = 4;
|
||||
}
|
||||
|
||||
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
|
||||
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
||||
int r1ptg = 4; // num src1 rows per threadgroup
|
||||
@@ -3406,6 +3457,14 @@ static int ggml_metal_encode_node(
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
switch (r1ptg) {
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
|
||||
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
|
||||
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
|
||||
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
|
||||
default: GGML_ABORT("not implemented");
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
switch (r1ptg) {
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
|
||||
@@ -3560,7 +3619,7 @@ static int ggml_metal_encode_node(
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
|
||||
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
||||
@@ -3878,38 +3937,6 @@ static int ggml_metal_encode_node(
|
||||
default: break;
|
||||
}
|
||||
|
||||
const int64_t neh10 = ne10; // n_embd
|
||||
const int64_t neh11 = ne21; // n_tokens
|
||||
const int64_t neh12 = ne02; // n_expert
|
||||
|
||||
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
|
||||
const uint64_t nbh11 = nbh10*neh10;
|
||||
const uint64_t nbh12 = nbh11*neh11;
|
||||
const uint64_t nbh13 = nbh12*neh12;
|
||||
|
||||
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
|
||||
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
||||
if (!h_src1) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int64_t neh0 = ne0;
|
||||
const int64_t neh1 = ne21;
|
||||
const int64_t neh2 = ne02;
|
||||
|
||||
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
|
||||
const uint64_t nbh1 = nbh0*neh0;
|
||||
const uint64_t nbh2 = nbh1*neh1;
|
||||
//const uint64_t nbh3 = nbh2*neh2;
|
||||
|
||||
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
|
||||
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
||||
if (!h_dst) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// tokens per expert
|
||||
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
|
||||
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
||||
@@ -3919,8 +3946,8 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
// id map
|
||||
// [n_expert_used, n_tokens]
|
||||
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
|
||||
// [n_tokens, n_expert]
|
||||
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
|
||||
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
||||
if (!h_ids) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
||||
@@ -3928,32 +3955,45 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
{
|
||||
const int nth = MIN(1024, ne10/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map0 args = {
|
||||
ne02,
|
||||
ne10,
|
||||
ne11, // n_expert_used (bcast)
|
||||
ne11, // n_expert_used (bcast)
|
||||
nb11,
|
||||
nb12,
|
||||
neh11, // n_tokens
|
||||
nbh11,
|
||||
ne20, // n_expert_used
|
||||
ne21, // n_tokens
|
||||
ne20, // n_expert_used
|
||||
nb21,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
|
||||
pipeline = nil;
|
||||
|
||||
switch (ne20) {
|
||||
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
|
||||
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
|
||||
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
|
||||
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
|
||||
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
|
||||
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
|
||||
const size_t smem = ne02*ne20*sizeof(uint16_t);
|
||||
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:4];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:5];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:2];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:3];
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
@@ -3992,13 +4032,15 @@ static int ggml_metal_encode_node(
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.neh12 =*/ neh12,
|
||||
/*.nbh10 =*/ nbh10,
|
||||
/*.nbh11 =*/ nbh11,
|
||||
/*.nbh12 =*/ nbh12,
|
||||
/*.nbh13 =*/ nbh13,
|
||||
/*.neh0 =*/ neh0,
|
||||
/*.neh1 =*/ neh1,
|
||||
/*.ne11 =*/ ne11, // n_expert_used (bcast)
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne20 =*/ ne20, // n_expert_used
|
||||
/*.ne21 =*/ ne21, // n_tokens
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.r2 =*/ r2,
|
||||
/*.r3 =*/ r3,
|
||||
};
|
||||
@@ -4006,42 +4048,14 @@ static int ggml_metal_encode_node(
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:2];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:4];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:4];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(ne0 % 4 == 0);
|
||||
|
||||
const int nth = MIN(1024, ne0/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map1 args = {
|
||||
ne20, // n_expert_used
|
||||
neh0,
|
||||
neh1,
|
||||
nbh1,
|
||||
nbh2,
|
||||
ne0,
|
||||
nb1,
|
||||
nb2,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:1];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
} else {
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
@@ -4701,7 +4715,6 @@ static int ggml_metal_encode_node(
|
||||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
@@ -5117,10 +5130,8 @@ static int ggml_metal_encode_node(
|
||||
|
||||
bool use_vec_kernel = false;
|
||||
|
||||
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
||||
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
||||
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
||||
if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
||||
// use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
|
||||
if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
@@ -5130,6 +5141,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
@@ -5154,6 +5166,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
||||
@@ -5178,6 +5191,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
||||
@@ -5202,6 +5216,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
||||
@@ -5226,6 +5241,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
||||
@@ -5250,6 +5266,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
||||
@@ -5274,6 +5291,7 @@ static int ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
||||
@@ -5465,6 +5483,7 @@ static int ggml_metal_encode_node(
|
||||
/*.nb33 =*/ nb33,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
@@ -5488,7 +5507,6 @@ static int ggml_metal_encode_node(
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
if (!use_vec_kernel) {
|
||||
// half8x8 kernel
|
||||
@@ -5514,7 +5532,7 @@ static int ggml_metal_encode_node(
|
||||
|
||||
while (true) {
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
if (smem > device.maxThreadgroupMemoryLength/2) {
|
||||
break;
|
||||
}
|
||||
nsgmax *= 2;
|
||||
@@ -5526,15 +5544,18 @@ static int ggml_metal_encode_node(
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
#undef FATTN_SMEM
|
||||
} else {
|
||||
// half4x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
@@ -5544,15 +5565,17 @@ static int ggml_metal_encode_node(
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the soft_max values and the mask
|
||||
//
|
||||
// ne00*(nsg)
|
||||
// ne20*(nsg)
|
||||
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
|
||||
//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
while (true) {
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
|
||||
if (smem > device.maxThreadgroupMemoryLength/2) {
|
||||
break;
|
||||
}
|
||||
nsgmax *= 2;
|
||||
@@ -5560,7 +5583,7 @@ static int ggml_metal_encode_node(
|
||||
nsgmax /= 2;
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
||||
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
||||
|
||||
int64_t nsg = 1;
|
||||
while (nsg <= nsgt) {
|
||||
@@ -5568,13 +5591,74 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
nsg /= 2;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
// workgroups
|
||||
// each workgroup handles nsg*nkpsg cache values
|
||||
uint16_t nwg = 1;
|
||||
if (4*nsg*nkpsg >= ne11) {
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
// using 1 workgroup -> write the result directly into dst
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
nwg = 32;
|
||||
nsg = MIN(4, nsg);
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
// sanity checks
|
||||
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
||||
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
|
||||
|
||||
const int32_t nrows = ne1*ne2*ne3;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the head vector
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
|
||||
id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
|
||||
if (!h_tmp) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
|
||||
return 0;
|
||||
}
|
||||
|
||||
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
|
||||
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
|
||||
|
||||
[encoder setBuffer:h_tmp offset:0 atIndex:6];
|
||||
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
|
||||
// reduce the results from the workgroups
|
||||
{
|
||||
ggml_metal_kargs_flash_attn_ext_reduce args0 = {
|
||||
nrows,
|
||||
ne20,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline0];
|
||||
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
|
||||
[encoder setBuffer:h_tmp offset:0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
|
||||
}
|
||||
}
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
|
||||
@@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
@@ -974,9 +979,16 @@ kernel void kernel_mul(
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
||||
if (args.ne10 == 1) {
|
||||
const float x = *((device float *)(src1_ptr));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1000,9 +1012,16 @@ kernel void kernel_div(
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
||||
if (args.ne10 == 1) {
|
||||
const float x = 1.0f / *((device float *)(src1_ptr));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1964,14 +1983,15 @@ kernel void kernel_ssm_scan_f32(
|
||||
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
||||
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||
const int64_t i = i0 + i1*nc;
|
||||
const int64_t g = ir / (nh / ng); // repeat_interleave
|
||||
float s0 = s0_buff[i];
|
||||
float s = s_buff[i];
|
||||
|
||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
|
||||
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
||||
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
|
||||
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
||||
|
||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||
@@ -2079,14 +2099,15 @@ kernel void kernel_ssm_scan_f32_group(
|
||||
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
||||
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||
const int64_t i = i0 + i1*nc;
|
||||
const int64_t g = ir / (nh / ng); // repeat_interleave
|
||||
float s0 = s0_buff[i];
|
||||
float s = s_buff[i];
|
||||
|
||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
||||
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
||||
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
||||
device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
|
||||
device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
|
||||
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
||||
|
||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||
@@ -3001,7 +3022,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||
#pragma unroll(r1ptg)
|
||||
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3186,6 +3206,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
|
||||
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
|
||||
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
|
||||
@@ -4663,6 +4688,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
||||
@@ -4674,6 +4700,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
||||
@@ -4685,6 +4712,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
||||
@@ -4695,6 +4723,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
||||
@@ -4705,6 +4734,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
||||
@@ -4715,6 +4745,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
||||
@@ -4725,6 +4756,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
||||
@@ -4765,14 +4797,19 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * mask,
|
||||
device const char * sinks,
|
||||
device char * dst,
|
||||
constant uint16_t & nwg,
|
||||
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
static_assert(DK % 32 == 0, "DK must be divisible by 32");
|
||||
static_assert(DV % 32 == 0, "DV must be divisible by 32");
|
||||
|
||||
const int iq3 = tgpig[2];
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
const short iwg = tgpig[2]%nwg;
|
||||
|
||||
const int iq3 = tgpig[2]/nwg;
|
||||
const int iq2 = tgpig[1];
|
||||
const int iq1 = tgpig[0];
|
||||
|
||||
@@ -4851,7 +4888,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
|
||||
for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
if (ic >= args.ne11) {
|
||||
break;
|
||||
@@ -4981,7 +5018,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
}
|
||||
|
||||
if (sinks != q && sgitg == 0) {
|
||||
if (sinks != q && sgitg == 0 && iwg == 0) {
|
||||
const float m = M;
|
||||
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
|
||||
|
||||
@@ -5090,14 +5127,25 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
const float S = ss[0];
|
||||
const int64_t nrows = args.ne3*args.ne2*args.ne1;
|
||||
const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
|
||||
|
||||
const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
|
||||
|
||||
// interleave the workgroup data
|
||||
for (short i = tiisg; i < DV4; i += NW) {
|
||||
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
|
||||
dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
|
||||
}
|
||||
|
||||
// store S and M
|
||||
if (nwg > 1 && tiisg == 0) {
|
||||
dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
|
||||
dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5187,6 +5235,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
kernel void kernel_flash_attn_ext_reduce(
|
||||
constant ggml_metal_kargs_flash_attn_ext_reduce & args,
|
||||
device const char * htmp,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const uint64_t rid = tgpig;
|
||||
|
||||
const short nwg = 32;
|
||||
const short iwg = tiisg;
|
||||
const short DV = args.ne20;
|
||||
const short DV4 = DV/4;
|
||||
|
||||
device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
|
||||
device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
|
||||
device float4 * dst4 = (device float4 *) dst + rid*DV4;
|
||||
|
||||
float S = ss[rid*(2*nwg) + 2*iwg + 0];
|
||||
float M = ss[rid*(2*nwg) + 2*iwg + 1];
|
||||
|
||||
const float m = simd_max(M);
|
||||
const float ms = exp(M - m);
|
||||
|
||||
S = 1.0f/simd_sum(S*ms);
|
||||
|
||||
for (int i = sgitg; i < DV4; i += nwg) {
|
||||
const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
|
||||
|
||||
if (iwg == 0) {
|
||||
dst4[i] = v*S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_set(
|
||||
constant ggml_metal_kargs_set & args,
|
||||
@@ -7474,97 +7557,81 @@ kernel void kernel_mul_mm(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T4>
|
||||
template<short ne20> // n_expert_used
|
||||
kernel void kernel_mul_mm_id_map0(
|
||||
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * hsrc1,
|
||||
device char * htpe,
|
||||
device char * hids,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ide = tgpig[0]; // expert id
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const short ide = tpitg; // expert id
|
||||
|
||||
int n_all = 0;
|
||||
uint32_t n_all = 0;
|
||||
|
||||
device int32_t * ids_i32 = (device int32_t *) (hids);
|
||||
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
|
||||
|
||||
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
|
||||
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
|
||||
if (i21 + tpitg < args.ne21) {
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
|
||||
|
||||
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
|
||||
if (src2_i32[i20] != ide) {
|
||||
continue;
|
||||
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
|
||||
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sids[i20] = src2_i32[i20];
|
||||
}
|
||||
|
||||
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
|
||||
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
|
||||
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
|
||||
}
|
||||
|
||||
++n_all;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short t = 0; t < ntg; t++) {
|
||||
if (i21 + t >= args.ne21) {
|
||||
break;
|
||||
}
|
||||
|
||||
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
|
||||
|
||||
short sel = 0;
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sel += (sids[i20] == ide)*(i20 + 1);
|
||||
}
|
||||
|
||||
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
|
||||
|
||||
n_all += sel > 0;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
device int32_t * tpe_i32 = (device int32_t *) (htpe);
|
||||
tpe_i32[ide] = n_all;
|
||||
}
|
||||
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
|
||||
tpe_u32[ide] = n_all;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
|
||||
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_mul_mm_id_map1(
|
||||
constant ggml_metal_kargs_mul_mm_id_map1 & args,
|
||||
device const char * hdst,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i20 = tgpig[0]; // used expert
|
||||
const int i21 = tgpig[1]; // token
|
||||
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
|
||||
|
||||
const int id = ids_i32[i21*args.ne20 + i20];
|
||||
|
||||
const int ide = id / args.neh1;
|
||||
const int idt = id % args.neh1;
|
||||
|
||||
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
|
||||
dst_f32x4[i0] = hdst_f32x4[i0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * tpe,
|
||||
device const char * htpe,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
@@ -7572,19 +7639,20 @@ kernel void kernel_mul_mm_id(
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
const int im = tgpig.z;
|
||||
const int im = tgpig.z; // expert
|
||||
|
||||
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
|
||||
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
|
||||
const int neh1 = tpe_i32[im];
|
||||
const int32_t neh1 = tpe_u32[im];
|
||||
|
||||
if (r1*BLOCK_SIZE_N >= neh1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
@@ -7600,20 +7668,23 @@ kernel void kernel_mul_mm_id(
|
||||
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
|
||||
const int i12 = im%args.neh12;
|
||||
const int i13 = im/args.neh12;
|
||||
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
|
||||
|
||||
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const short i11 = (id % args.ne20) % args.ne11;
|
||||
const short i12 = (id / args.ne20);
|
||||
const short i13 = 0;
|
||||
|
||||
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
|
||||
const short offset1 = il/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
device const half * y = (device const half *)(src1
|
||||
+ args.nbh13*i13
|
||||
+ args.nbh12*i12
|
||||
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
|
||||
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
device const float * y = (device const float *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*i11
|
||||
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
@@ -7629,7 +7700,7 @@ kernel void kernel_mul_mm_id(
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
@@ -7665,43 +7736,38 @@ kernel void kernel_mul_mm_id(
|
||||
}
|
||||
}
|
||||
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
|
||||
device float * C = (device float *) dst +
|
||||
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
||||
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
||||
|
||||
#pragma unroll(8)
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short j = sgitg; j < n_cols; j += 4) {
|
||||
const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
|
||||
|
||||
const short ide = id % args.ne20;
|
||||
const short idt = id / args.ne20;
|
||||
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = tiisg;
|
||||
for (; i < n_rows/4; i += 32) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = 0;
|
||||
for (; i < n_rows/4; i++) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i *= 4;
|
||||
for (; i < n_rows; i++) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
i = (4*(n_rows/4)) + tiisg;
|
||||
for (; i < n_rows; i += 32) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -420,9 +420,9 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_clamp;
|
||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
|
||||
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
|
||||
cl_kernel kernel_norm;
|
||||
cl_kernel kernel_norm, kernel_norm_mul_add;
|
||||
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
|
||||
cl_kernel kernel_group_norm;
|
||||
cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
|
||||
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
|
||||
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
||||
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
|
||||
@@ -1161,7 +1161,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
backend_ctx->program_norm =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
@@ -1487,7 +1488,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
backend_ctx->program_group_norm =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
@@ -2498,12 +2500,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
||||
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
||||
return false;
|
||||
}
|
||||
} else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
|
||||
const ggml_tensor *norm = cgraph->nodes[node_idx];
|
||||
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
|
||||
const ggml_tensor *add = cgraph->nodes[node_idx+2];
|
||||
const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0];
|
||||
const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
|
||||
|
||||
// norm fusion only supports F32
|
||||
if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (norm->src[0]->ne[0] % 4 != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
|
||||
return false;
|
||||
}
|
||||
} else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) {
|
||||
const ggml_tensor *gn = cgraph->nodes[node_idx];
|
||||
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
|
||||
const ggml_tensor *add = cgraph->nodes[node_idx+2];
|
||||
const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0];
|
||||
const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0];
|
||||
|
||||
if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
|
||||
static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
|
||||
static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
|
||||
|
||||
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
@@ -2520,6 +2557,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
|
||||
ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
|
||||
ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
|
||||
i++;
|
||||
@@ -2647,8 +2694,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_RMS_NORM:
|
||||
return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_REPEAT:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
|
||||
case GGML_OP_PAD:
|
||||
@@ -2728,10 +2776,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (op->src[4]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * q = op->src[0];
|
||||
const ggml_tensor * k = op->src[1];
|
||||
const ggml_tensor * v = op->src[2];
|
||||
@@ -5038,6 +5082,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
|
||||
GGML_ASSERT(norm_tensor && mul_tensor && add_tensor);
|
||||
|
||||
const ggml_tensor * src0 = norm_tensor->src[0];
|
||||
const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
|
||||
const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
|
||||
const ggml_tensor * dst = add_tensor;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offset2 = extra2->offset + src2->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, norm_tensor->op_params, sizeof(float));
|
||||
|
||||
const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
|
||||
const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
|
||||
const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3];
|
||||
const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
|
||||
const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3];
|
||||
const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3];
|
||||
const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3];
|
||||
|
||||
size_t sgs;
|
||||
if (backend_ctx->gpu_family == ADRENO) sgs = 64;
|
||||
else if (backend_ctx->gpu_family == INTEL) sgs = 32;
|
||||
else GGML_ASSERT(false && "Unsupported GPU");
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_norm_mul_add;
|
||||
|
||||
int nth = sgs;
|
||||
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
|
||||
while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2;
|
||||
nth = MIN(nth, max_workgroup_size);
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t lws[] = {(size_t)nth, 1, 1};
|
||||
size_t num_subgroups = (nth + sgs - 1) / sgs;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20));
|
||||
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21));
|
||||
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22));
|
||||
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23));
|
||||
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21));
|
||||
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22));
|
||||
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23));
|
||||
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps));
|
||||
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL));
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst);
|
||||
}
|
||||
|
||||
static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
|
||||
GGML_ASSERT(gn_tensor && mul_tensor && add_tensor);
|
||||
|
||||
const ggml_tensor * src0 = gn_tensor->src[0];
|
||||
const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0];
|
||||
const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0];
|
||||
const ggml_tensor * dst = add_tensor;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offset2 = extra2->offset + src2->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
int groups;
|
||||
float eps;
|
||||
memcpy(&groups, gn_tensor->op_params, sizeof(int));
|
||||
memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float));
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add;
|
||||
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
|
||||
int ne = ggml_nelements(src0);
|
||||
int group_size = ne / groups;
|
||||
|
||||
size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) };
|
||||
size_t gws[] = { (size_t)groups * lws[0] };
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps));
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -5583,6 +5761,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
|
||||
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
GGML_ASSERT(q->extra);
|
||||
GGML_ASSERT(k->extra);
|
||||
GGML_ASSERT(v->extra);
|
||||
@@ -5590,6 +5769,9 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
||||
if (mask) {
|
||||
GGML_ASSERT(mask->extra);
|
||||
}
|
||||
if (sinks) {
|
||||
GGML_ASSERT(sinks->extra);
|
||||
}
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
@@ -5631,6 +5813,7 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
||||
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
|
||||
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
|
||||
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
|
||||
ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
|
||||
|
||||
cl_ulong offset_q = extra_q->offset + q->view_offs;
|
||||
cl_ulong offset_k = extra_k->offset + k->view_offs;
|
||||
@@ -5638,6 +5821,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
||||
cl_ulong offset_o = extra_o->offset + dst->view_offs;
|
||||
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
|
||||
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
|
||||
cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
|
||||
cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
|
||||
|
||||
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
|
||||
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
|
||||
@@ -5692,6 +5877,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
|
||||
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
|
||||
CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
|
||||
|
||||
if (n_q == 1) {
|
||||
const size_t wg_size = 64;
|
||||
|
||||
@@ -49,7 +49,9 @@ __kernel void flash_attn_f16(
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int block_q_idx = get_group_id(0);
|
||||
@@ -171,6 +173,20 @@ __kernel void flash_attn_f16(
|
||||
}
|
||||
|
||||
if (my_query_row < n_q) {
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
|
||||
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
|
||||
}
|
||||
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
@@ -214,7 +230,9 @@ __kernel void flash_attn_f16_q1(
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
@@ -247,7 +265,12 @@ __kernel void flash_attn_f16_q1(
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
const global ACC_TYPE* sinks_ptr = NULL;
|
||||
if (sinks_void != NULL) {
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
@@ -320,7 +343,11 @@ __kernel void flash_attn_f16_q1(
|
||||
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
const ACC_TYPE l_final = local_l[0];
|
||||
ACC_TYPE l_final = local_l[0];
|
||||
|
||||
if (sinks_ptr != NULL) {
|
||||
l_final += exp(sinks_ptr[head_idx] - m_final);
|
||||
}
|
||||
|
||||
if (l_final > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_final;
|
||||
|
||||
@@ -49,7 +49,9 @@ __kernel void flash_attn_f32(
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int block_q_idx = get_group_id(0);
|
||||
@@ -171,6 +173,20 @@ __kernel void flash_attn_f32(
|
||||
}
|
||||
|
||||
if (my_query_row < n_q) {
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
|
||||
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
|
||||
}
|
||||
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
@@ -214,7 +230,9 @@ __kernel void flash_attn_f32_q1(
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
@@ -247,7 +265,12 @@ __kernel void flash_attn_f32_q1(
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
const global ACC_TYPE* sinks_ptr = NULL;
|
||||
if (sinks_void != NULL) {
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
@@ -320,7 +343,11 @@ __kernel void flash_attn_f32_q1(
|
||||
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
const ACC_TYPE l_final = local_l[0];
|
||||
ACC_TYPE l_final = local_l[0];
|
||||
|
||||
if (sinks_ptr != NULL) {
|
||||
l_final += exp(sinks_ptr[head_idx] - m_final);
|
||||
}
|
||||
|
||||
if (l_final > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_final;
|
||||
|
||||
@@ -52,7 +52,9 @@ __kernel void flash_attn_f32_f16(
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int block_q_idx = get_group_id(0);
|
||||
@@ -174,6 +176,20 @@ __kernel void flash_attn_f32_f16(
|
||||
}
|
||||
|
||||
if (my_query_row < n_q) {
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
|
||||
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
|
||||
}
|
||||
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
@@ -217,7 +233,9 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
@@ -250,7 +268,12 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
const global ACC_TYPE* sinks_ptr = NULL;
|
||||
if (sinks_void != NULL) {
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
||||
@@ -323,7 +346,11 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
const ACC_TYPE l_final = local_l[0];
|
||||
ACC_TYPE l_final = local_l[0];
|
||||
|
||||
if (sinks_ptr != NULL) {
|
||||
l_final += exp(sinks_ptr[head_idx] - m_final);
|
||||
}
|
||||
|
||||
if (l_final > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_final;
|
||||
|
||||
@@ -70,3 +70,52 @@ kernel void kernel_group_norm(
|
||||
dst[j] *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// group_norm_mul_add
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_32
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_group_norm_mul_add(
|
||||
global float * src0, ulong offset0,
|
||||
global float * src1, ulong offset1,
|
||||
global float * src2, ulong offset2,
|
||||
global float * dst, ulong offsetd,
|
||||
int ne,
|
||||
int group_size,
|
||||
float eps
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
src1 = (global float *)((global char *)src1 + offset1);
|
||||
src2 = (global float *)((global char *)src2 + offset2);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
|
||||
int start = get_group_id(0) * group_size;
|
||||
int end = start + group_size;
|
||||
if (end > ne) {
|
||||
end = ne;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
float sum_sq = 0.0f;
|
||||
|
||||
for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
|
||||
float val = src0[j];
|
||||
sum += val;
|
||||
sum_sq += val*val;
|
||||
}
|
||||
|
||||
sum = sub_group_reduce_add(sum);
|
||||
sum_sq = sub_group_reduce_add(sum_sq);
|
||||
|
||||
const float mean = sum / group_size;
|
||||
const float var = sum_sq / group_size - mean * mean;
|
||||
const float scale = rsqrt(var + eps);
|
||||
|
||||
for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) {
|
||||
dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,3 +79,83 @@ kernel void kernel_norm(
|
||||
y[i00] = y[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// norm_mul_add
|
||||
//------------------------------------------------------------------------------
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_32
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_norm_mul_add(
|
||||
global char * src0_ptr, ulong src0_offset,
|
||||
global char * src1_ptr, ulong src1_offset,
|
||||
global char * src2_ptr, ulong src2_offset,
|
||||
global char * dst_ptr, ulong dst_offset,
|
||||
int ne00, int ne01, int ne02, int ne03,
|
||||
ulong nb01, ulong nb02, ulong nb03,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
ulong nb11, ulong nb12, ulong nb13,
|
||||
int ne20, int ne21, int ne22, int ne23,
|
||||
ulong nb21, ulong nb22, ulong nb23,
|
||||
ulong nbd1, ulong nbd2, ulong nbd3,
|
||||
float eps,
|
||||
local float2 * sums
|
||||
) {
|
||||
const int i03 = get_group_id(2);
|
||||
const int i02 = get_group_id(1);
|
||||
const int i01 = get_group_id(0);
|
||||
|
||||
global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13);
|
||||
global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23);
|
||||
global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3);
|
||||
|
||||
float p_sum = 0.0f;
|
||||
float p_sum_sq = 0.0f;
|
||||
|
||||
const int n_chunks = ne00 / 4;
|
||||
for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
|
||||
float4 val = x[i00];
|
||||
p_sum += val.x + val.y + val.z + val.w;
|
||||
p_sum_sq += dot(val, val);
|
||||
}
|
||||
|
||||
p_sum = sub_group_reduce_add(p_sum);
|
||||
p_sum_sq = sub_group_reduce_add(p_sum_sq);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq);
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (get_local_id(0) == 0) {
|
||||
float sum = 0.0f;
|
||||
float sum_sq = 0.0f;
|
||||
for (uint i = 0; i < get_num_sub_groups(); ++i) {
|
||||
float2 s = sums[i];
|
||||
sum += s.x;
|
||||
sum_sq += s.y;
|
||||
}
|
||||
|
||||
const float inv_ne00 = 1.0f / (float)ne00;
|
||||
const float mean = sum * inv_ne00;
|
||||
const float variance = mad(-mean, mean, sum_sq * inv_ne00);
|
||||
|
||||
sums[0] = (float2)(mean, rsqrt(variance + eps));
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
const float2 mean_scale = sums[0];
|
||||
const float mean = mean_scale.x;
|
||||
const float scale = mean_scale.y;
|
||||
const float neg_mean_scale = -mean * scale;
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) {
|
||||
const int w_idx = ne10 > 1 ? i00 : 0;
|
||||
const int b_idx = ne20 > 1 ? i00 : 0;
|
||||
const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale);
|
||||
y[i00] = mad(norm_x, w[w_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4364,11 +4364,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
|
||||
#endif
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
|
||||
case GGML_OP_SCALE:
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
@@ -4391,10 +4392,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
|
||||
+1128
-553
File diff suppressed because it is too large
Load Diff
@@ -1,20 +1,34 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
|
||||
const uint num_threads = 256;
|
||||
|
||||
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= p.ne) {
|
||||
continue;
|
||||
@@ -22,8 +36,34 @@ void main() {
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
|
||||
sum_sq += sum*sum;
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.param3 != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -334,6 +334,9 @@ void main() {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] *= Lfrcp[r];
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32;
|
||||
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||
|
||||
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
|
||||
const uint32_t HSK_pad = (HSK + 15) & ~15;
|
||||
const uint32_t HSV_pad = (HSV + 15) & ~15;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
|
||||
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
||||
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
||||
|
||||
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
|
||||
shared ACC_TYPE sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
|
||||
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 ksh[Bc * kshstride];
|
||||
|
||||
shared float slope[Br];
|
||||
@@ -74,6 +74,21 @@ void main() {
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
|
||||
if ((HSK % 16) != 0) {
|
||||
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Br * qstride) {
|
||||
Qf[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Bc * kshstride) {
|
||||
ksh[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
@@ -151,14 +166,14 @@ void main() {
|
||||
}
|
||||
barrier();
|
||||
|
||||
// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
|
||||
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
|
||||
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
||||
// This is written transposed in order to allow for N being 8 if implementations need it
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
for (uint32_t d = 0; d < HSK / 16; ++d) {
|
||||
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
@@ -358,6 +373,9 @@ void main() {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -104,16 +104,16 @@ void main() {
|
||||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
|
||||
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
|
||||
Qf16 *= float16_t(p.scale);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
||||
|
||||
@@ -140,10 +140,10 @@ void main() {
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
||||
|
||||
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
|
||||
S = coopMatMulAdd(Qf16, K_T, S);
|
||||
|
||||
if (p.logit_softcap != 0.0f) {
|
||||
@@ -208,31 +208,31 @@ void main() {
|
||||
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
||||
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
|
||||
|
||||
L = eM*L + rowsum;
|
||||
|
||||
// This is the "diagonal" matrix in the paper, but since we do componentwise
|
||||
// multiply rather than matrix multiply it has the diagonal element smeared
|
||||
// across the row
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
|
||||
|
||||
// resize eM by using smear/reduce
|
||||
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
// multiply with fp16 accumulation, then add to O.
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
|
||||
PV = coopMatMulAdd(P_A, V, PV);
|
||||
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
|
||||
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
|
||||
}
|
||||
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
@@ -243,16 +243,16 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
|
||||
|
||||
// resize L by using smear/reduce
|
||||
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
|
||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
|
||||
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
|
||||
|
||||
// resize M by using smear/reduce
|
||||
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
||||
@@ -283,9 +283,13 @@ void main() {
|
||||
|
||||
O = Ldiag*O;
|
||||
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
if (p.gqa_ratio > 1) {
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
} else {
|
||||
@@ -295,6 +299,6 @@ void main() {
|
||||
// permute dimensions
|
||||
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
||||
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
|
||||
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +111,10 @@ void main() {
|
||||
}
|
||||
}
|
||||
O *= L;
|
||||
|
||||
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
|
||||
O = clamp(O, -FLT_MAX, FLT_MAX);
|
||||
|
||||
data_d[iq3 * D * N + D * n + d] = O;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint i00 = gl_GlobalInvocationID.x;
|
||||
const uint i10 = gl_GlobalInvocationID.y;
|
||||
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
|
||||
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
|
||||
|
||||
if (i00 >= p.ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
||||
uint gid_z = gl_GlobalInvocationID.z;
|
||||
while (gid_z < p.ne11 * p.ne12) {
|
||||
uint gid_y = gl_GlobalInvocationID.y;
|
||||
while (gid_y < p.ne10) {
|
||||
const uint i10 = gid_y;
|
||||
const uint i11 = gid_z / p.ne12;
|
||||
const uint i12 = gid_z % p.ne12;
|
||||
|
||||
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
||||
|
||||
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
|
||||
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
|
||||
#else
|
||||
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
|
||||
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
|
||||
#endif
|
||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||
data_d[d_offset + i00] = D_TYPE(v);
|
||||
data_d[d_offset + i00] = D_TYPE(v);
|
||||
#else
|
||||
data_d[d_offset + i00] = D_TYPE(v);
|
||||
data_d[d_offset + i00] = D_TYPE(v);
|
||||
#endif
|
||||
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||
}
|
||||
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint i00 = (gl_GlobalInvocationID.x)*2;
|
||||
const uint i10 = gl_GlobalInvocationID.y;
|
||||
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
|
||||
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
|
||||
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
@@ -22,20 +19,33 @@ void main() {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
||||
uint gid_z = gl_GlobalInvocationID.z;
|
||||
while (gid_z < p.ne11 * p.ne12) {
|
||||
uint gid_y = gl_GlobalInvocationID.y;
|
||||
while (gid_y < p.ne10) {
|
||||
const uint i10 = gid_y;
|
||||
const uint i11 = gid_z / p.ne12;
|
||||
const uint i12 = gid_z % p.ne12;
|
||||
|
||||
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
||||
|
||||
const uint ib = a_offset + i00/QUANT_K; // block index
|
||||
const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
|
||||
const uint iybs = i00 - i00%QUANT_K; // dst block start index
|
||||
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||
|
||||
vec2 v = dequantize(ib, iqs, 0);
|
||||
const vec2 dm = get_dm(ib, 0);
|
||||
v = v * dm.x + dm.y;
|
||||
const uint ib = a_offset + i00/QUANT_K; // block index
|
||||
const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
|
||||
const uint iybs = i00 - i00%QUANT_K; // dst block start index
|
||||
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||
|
||||
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
|
||||
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
|
||||
vec2 v = dequantize(ib, iqs, 0);
|
||||
const vec2 dm = get_dm(ib, 0);
|
||||
v = v * dm.x + dm.y;
|
||||
|
||||
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
|
||||
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
|
||||
|
||||
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||
}
|
||||
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_8bit_storage : require
|
||||
#if USE_SUBGROUP_ADD
|
||||
|
||||
#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#endif
|
||||
@@ -12,10 +13,19 @@
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
#ifndef MMQ
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#else
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
#ifdef B_TYPE_VEC2
|
||||
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||
#endif
|
||||
#ifdef B_TYPE_VEC4
|
||||
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||
#endif
|
||||
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -92,6 +102,23 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||
layout (constant_id = 2) const uint NUM_COLS = 1;
|
||||
|
||||
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
|
||||
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
temp[j][n] = subgroupAdd(temp[j][n]);
|
||||
}
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
|
||||
@@ -152,3 +179,4 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
|
||||
#define MMQ
|
||||
#define B_TYPE block_q8_1_x4
|
||||
|
||||
#include "mul_mat_vec_base.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#define K_PER_ITER 8
|
||||
|
||||
#include "mul_mmq_funcs.comp"
|
||||
|
||||
uint a_offset, b_offset, d_offset;
|
||||
|
||||
int32_t cache_b_qs[2];
|
||||
vec2 cache_b_ds;
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
|
||||
|
||||
// Preload data_b block
|
||||
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
|
||||
const uint b_qs_idx = tid % 4;
|
||||
const uint b_block_idx_outer = b_block_idx / 4;
|
||||
const uint b_block_idx_inner = b_block_idx % 4;
|
||||
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
|
||||
|
||||
#if QUANT_R == 2
|
||||
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
|
||||
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
|
||||
#else
|
||||
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
|
||||
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
|
||||
#endif
|
||||
|
||||
uint ibi = first_row*p.ncols;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
|
||||
ibi += p.ncols;
|
||||
|
||||
int32_t q_sum = 0;
|
||||
#if QUANT_R == 2
|
||||
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs.x,
|
||||
cache_b_qs[0]);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs.y,
|
||||
cache_b_qs[1]);
|
||||
#else
|
||||
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs,
|
||||
cache_b_qs[0]);
|
||||
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
|
||||
q_sum += dotPacked4x8EXT(data_a_qs,
|
||||
cache_b_qs[1]);
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
|
||||
#else
|
||||
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
a_offset /= QUANT_K;
|
||||
b_offset /= QUANT_K_Q8_1;
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
temp[j][n] = FLOAT_TYPE(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
||||
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
|
||||
num_iters++;
|
||||
}
|
||||
int unroll_count = 4;
|
||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
uint i = 0;
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
#endif
|
||||
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
while (i < num_iters) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||
compute_outputs(first_row, NUM_ROWS);
|
||||
} else {
|
||||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,9 @@
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
@@ -103,16 +106,79 @@ layout (constant_id = 10) const uint WARP = 32;
|
||||
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
uint _ne1;
|
||||
#ifdef COOPMAT
|
||||
shared uint _ne1_sh;
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[BN];
|
||||
uint _ne1;
|
||||
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
shared uvec4 ballots_sh[NUM_WARPS];
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif // MUL_MAT_ID_USE_SUBGROUPS
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
@@ -177,51 +243,20 @@ void main() {
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#ifdef COOPMAT
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||
if (_ne1 >= ic * BN) {
|
||||
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
@@ -767,7 +802,7 @@ void main() {
|
||||
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
||||
#if LOAD_VEC_B == 8
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
@@ -783,7 +818,7 @@ void main() {
|
||||
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
||||
#elif LOAD_VEC_B == 4
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
@@ -802,7 +837,7 @@ void main() {
|
||||
#else
|
||||
const uint row_i = ic * BN + loadc_b + l;
|
||||
if (row_i < _ne1 && block + loadr_b < end_k) {
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[loadc_b + l];
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||
} else {
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
@@ -856,6 +891,20 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
|
||||
[[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
|
||||
sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
const uint dr = ir * BM + warp_r * WM;
|
||||
const uint dc = ic * BN + warp_c * WN;
|
||||
|
||||
@@ -873,7 +922,7 @@ void main() {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
|
||||
if (dr + cm_row * TM + store_r < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
@@ -923,7 +972,7 @@ void main() {
|
||||
const uint row_i = dc_warp + cc;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
#ifdef MUL_MAT_ID
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "utils.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
@@ -92,14 +93,15 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
shared u16vec4 row_ids[4096];
|
||||
shared u16vec4 row_ids[BN];
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
B_TYPE b[];
|
||||
};
|
||||
|
||||
uint _ne1;
|
||||
shared uint _ne1_sh;
|
||||
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||
|
||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
@@ -109,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
|
||||
return B_TYPE(0.0);
|
||||
}
|
||||
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
|
||||
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
|
||||
|
||||
return ret;
|
||||
@@ -121,13 +123,74 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
||||
uint dc = ic * BN + c;
|
||||
|
||||
if (dr < p.M && dc < _ne1) {
|
||||
uint row_i = dc;
|
||||
uint row_i = c;
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_LocalInvocationIndex;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1;
|
||||
if (nei0_is_pow2) {
|
||||
ii1 = i >> nei0shift;
|
||||
} else {
|
||||
ii1 = i / p.nei0;
|
||||
}
|
||||
uint ii0 = i - ii1 * p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
uint total = 0;
|
||||
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||
if (k == gl_SubgroupID) {
|
||||
subgroup_base = total;
|
||||
}
|
||||
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
@@ -157,45 +220,12 @@ void main() {
|
||||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
#endif
|
||||
@@ -319,6 +349,10 @@ void main() {
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
|
||||
@@ -358,6 +392,10 @@ void main() {
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
|
||||
@@ -398,6 +436,10 @@ void main() {
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
block_k += BK;
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
|
||||
@@ -414,18 +456,111 @@ void main() {
|
||||
|
||||
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
uint k_iters = (end_k - start_k + BK - 1) / BK;
|
||||
|
||||
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
|
||||
store_scales(tid);
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
if ((block_k % QUANT_K) == 0) {
|
||||
store_scales(tid);
|
||||
}
|
||||
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
|
||||
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
// Call callback to store each element, remapping row through shared memory
|
||||
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
|
||||
return;
|
||||
}
|
||||
if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
if ((block_k % QUANT_K) == 0) {
|
||||
store_scales(tid);
|
||||
}
|
||||
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
|
||||
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
// Call callback to store each element, remapping row through shared memory
|
||||
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
store_scales(tid);
|
||||
if (block_k + BK < end_k) {
|
||||
if ((block_k % QUANT_K) == 0) {
|
||||
store_scales(tid);
|
||||
}
|
||||
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
@@ -455,6 +590,9 @@ void main() {
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
|
||||
|
||||
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
|
||||
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
||||
#endif
|
||||
|
||||
#define LOAD_VEC_A (4 * QUANT_R)
|
||||
#define LOAD_VEC_B 4
|
||||
#define LOAD_VEC_B 16
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
@@ -270,15 +270,22 @@ void main() {
|
||||
const uint iqs = idx & 0x7;
|
||||
#else
|
||||
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
|
||||
const uint ib_outer = ib / 4;
|
||||
const uint ib_inner = ib % 4;
|
||||
|
||||
const uint iqs = loadr_b;
|
||||
#endif
|
||||
|
||||
const uint buf_ib = loadc_b + l;
|
||||
|
||||
if (iqs == 0) {
|
||||
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
|
||||
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
|
||||
}
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
|
||||
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
|
||||
}
|
||||
|
||||
barrier();
|
||||
@@ -349,7 +356,7 @@ void main() {
|
||||
cache_b_qs[cc * (BK / 4) + idx_k]);
|
||||
}
|
||||
|
||||
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
|
||||
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
|
||||
data_a[ib].qs[iqs * 2 + 1]));
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
||||
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
|
||||
return ACC_TYPE(float(q_sum) * da * dsb.x);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_nonuniform_qualifier : enable
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "rte.comp"
|
||||
#include "types.comp"
|
||||
@@ -14,11 +18,18 @@ layout (push_constant) uniform parameter2
|
||||
uint ne20; uint ne21; uint ne22; uint ne23;
|
||||
|
||||
// strides for srcs+dst
|
||||
uint nb[8][4];
|
||||
uint nb[12][4];
|
||||
|
||||
uint rms_partials;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
|
||||
// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
|
||||
|
||||
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
|
||||
|
||||
layout(constant_id = 0) const uint num_srcs = 2;
|
||||
|
||||
@@ -42,14 +53,22 @@ const uint num_threads = 256;
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= ne) {
|
||||
continue;
|
||||
@@ -61,8 +80,32 @@ void main() {
|
||||
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
|
||||
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
|
||||
}
|
||||
sum_sq += sum*sum;
|
||||
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.rms_partials != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -3,6 +3,15 @@
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#ifdef USE_SUBGROUPS
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_clustered : require
|
||||
|
||||
#define INVOCATION_ID gl_SubgroupInvocationID.x
|
||||
#else
|
||||
#define INVOCATION_ID gl_LocalInvocationID.x
|
||||
#endif
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ne;
|
||||
@@ -14,13 +23,19 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {vec4 data_a[];};
|
||||
#ifndef QBLOCK_X4
|
||||
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
|
||||
#else
|
||||
layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
|
||||
#endif
|
||||
|
||||
#ifndef USE_SUBGROUPS
|
||||
shared float shmem[GROUP_SIZE];
|
||||
#endif
|
||||
|
||||
void quantize() {
|
||||
const uint wgid = gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint tid = INVOCATION_ID;
|
||||
|
||||
// Each thread handles a vec4, so 8 threads handle a block
|
||||
const uint blocks_per_group = GROUP_SIZE / 8;
|
||||
@@ -30,9 +45,19 @@ void quantize() {
|
||||
const uint ib = wgid * blocks_per_group + block_in_wg;
|
||||
const uint iqs = tid % 8;
|
||||
|
||||
#ifndef QBLOCK_X4
|
||||
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
|
||||
return;
|
||||
}
|
||||
#else
|
||||
const uint ibx4_outer = ib / 4;
|
||||
const uint ibx4_inner = ib % 4;
|
||||
|
||||
const uint required_x4_blocks = (p.ne + 127) / 128;
|
||||
if (ibx4_outer >= required_x4_blocks) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const uint a_idx = ib * 8 + iqs;
|
||||
|
||||
@@ -40,7 +65,9 @@ void quantize() {
|
||||
const vec4 abs_vals = abs(vals);
|
||||
|
||||
// Find absolute max for each block
|
||||
shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
||||
const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
||||
#ifndef USE_SUBGROUPS
|
||||
shmem[tid] = thread_max;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
|
||||
if (iqs < s) {
|
||||
@@ -50,14 +77,28 @@ void quantize() {
|
||||
}
|
||||
|
||||
const float amax = shmem[block_in_wg * 8];
|
||||
#else
|
||||
const float amax = subgroupClusteredMax(thread_max, 8);
|
||||
#endif
|
||||
|
||||
const float d = amax / 127.0;
|
||||
const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
|
||||
vals = round(vals * d_inv);
|
||||
|
||||
#ifndef QBLOCK_X4
|
||||
data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
|
||||
#else
|
||||
data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));
|
||||
#endif
|
||||
|
||||
#ifndef USE_SUBGROUPS
|
||||
barrier();
|
||||
#endif
|
||||
|
||||
// Calculate the sum for each block
|
||||
shmem[tid] = vals.x + vals.y + vals.z + vals.w;
|
||||
const float thread_sum = vals.x + vals.y + vals.z + vals.w;
|
||||
#ifndef USE_SUBGROUPS
|
||||
shmem[tid] = thread_sum;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
|
||||
if (iqs < s) {
|
||||
@@ -65,10 +106,19 @@ void quantize() {
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
#else
|
||||
const float sum = subgroupClusteredAdd(thread_sum, 8);
|
||||
#endif
|
||||
if (iqs == 0) {
|
||||
#ifndef USE_SUBGROUPS
|
||||
const float sum = shmem[tid];
|
||||
#endif
|
||||
|
||||
#ifndef QBLOCK_X4
|
||||
data_b[ib].ds = f16vec2(vec2(d, sum * d));
|
||||
#else
|
||||
data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
void rms_norm(uint num_iters) {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
@@ -30,38 +30,76 @@ void main() {
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
sum[tid] += xi * xi;
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
FLOAT_TYPE xi = FLOAT_TYPE(0);
|
||||
if (col < ncols) {
|
||||
xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
}
|
||||
sum += xi * xi;
|
||||
}
|
||||
|
||||
sumsh[tid] = sum;
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sum[tid] += sum[tid + s];
|
||||
sum += sumsh[tid + s];
|
||||
sumsh[tid] = sum;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
sum = sumsh[0];
|
||||
|
||||
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
// instantiate the rms_norm function for several different
|
||||
// dimensions, to allow loop unrolling
|
||||
uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if (num_blocks > 32) {
|
||||
rms_norm(num_blocks);
|
||||
} else if (num_blocks > 16) {
|
||||
rms_norm(32);
|
||||
} else if (num_blocks > 8) {
|
||||
rms_norm(16);
|
||||
} else if (num_blocks > 4) {
|
||||
rms_norm(8);
|
||||
} else if (num_blocks == 4) {
|
||||
rms_norm(4);
|
||||
} else if (num_blocks == 3) {
|
||||
rms_norm(3);
|
||||
} else if (num_blocks == 2) {
|
||||
rms_norm(2);
|
||||
} else if (num_blocks == 1) {
|
||||
rms_norm(1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_binary_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
|
||||
#define BLOCK_SIZE 128
|
||||
|
||||
layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};
|
||||
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
|
||||
const uint row = 0;
|
||||
const uint channel = gl_WorkGroupID.y;
|
||||
const uint samp = gl_WorkGroupID.z;
|
||||
// The work is split across multiple workgroups in the x dimension. Each invocation
|
||||
// processes one element
|
||||
const uint tid = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint stride_row = p.nb01;
|
||||
const uint stride_channel = p.nb02;
|
||||
const uint stride_sample = p.nb03;
|
||||
|
||||
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
uint32_t num_partials = p.param3;
|
||||
for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
|
||||
sum += partial_sums[i];
|
||||
}
|
||||
sum = subgroupAdd(sum);
|
||||
|
||||
uint col = tid;
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
@@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_cols;
|
||||
uint ne01, ne02;
|
||||
uint nb01, nb02, nb03;
|
||||
uint nb11, nb12, nb13;
|
||||
float weight;
|
||||
uint misalign_offsets;
|
||||
uint ne0_12mp, ne0_12L;
|
||||
uint ne0_1mp, ne0_1L;
|
||||
} p;
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
|
||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint col = gl_LocalInvocationID.x;
|
||||
const float weight = p.weight;
|
||||
|
||||
tmp[col] = FLOAT_TYPE(0.0f);
|
||||
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
|
||||
const uint i03_offset = i03 * p.ne01*p.ne02;
|
||||
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
|
||||
const uint i01 = row - i03_offset - i02*p.ne01;
|
||||
|
||||
for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
|
||||
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
|
||||
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
|
||||
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
|
||||
|
||||
tmp[col] = FLOAT_TYPE(0.0);
|
||||
|
||||
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
|
||||
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
|
||||
}
|
||||
|
||||
barrier();
|
||||
@@ -32,6 +65,6 @@ void main() {
|
||||
}
|
||||
|
||||
if (col == 0) {
|
||||
data_d[row] = D_TYPE(tmp[0]);
|
||||
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,6 +207,18 @@ struct block_q8_1_packed32
|
||||
int32_t qs[8];
|
||||
};
|
||||
|
||||
// 4 blocks in one to allow 16-byte/128-bit alignment and loads
|
||||
struct block_q8_1_x4
|
||||
{
|
||||
f16vec2 ds[4];
|
||||
int32_t qs[32];
|
||||
};
|
||||
struct block_q8_1_x4_packed128
|
||||
{
|
||||
f16vec2 ds[4];
|
||||
ivec4 qs[8];
|
||||
};
|
||||
|
||||
// K-quants
|
||||
#define QUANT_K_Q2_K 256
|
||||
|
||||
|
||||
@@ -68,6 +68,12 @@ const std::vector<std::string> type_names = {
|
||||
"bf16",
|
||||
};
|
||||
|
||||
enum MatMulIdType {
|
||||
NONE,
|
||||
DEFAULT,
|
||||
SUBGROUP,
|
||||
};
|
||||
|
||||
namespace {
|
||||
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
@@ -200,6 +206,22 @@ bool string_ends_with(const std::string& str, const std::string& suffix) {
|
||||
return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
|
||||
}
|
||||
|
||||
bool is_quantized_type(const std::string& type_name) {
|
||||
return type_name != "f32" && type_name != "f16" && type_name != "bf16";
|
||||
}
|
||||
|
||||
bool is_legacy_quant(const std::string& type_name) {
|
||||
return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0";
|
||||
}
|
||||
|
||||
bool is_k_quant(const std::string& type_name) {
|
||||
return string_ends_with(type_name, "_k");
|
||||
}
|
||||
|
||||
bool is_iq_quant(const std::string& type_name) {
|
||||
return string_starts_with(type_name, "iq");
|
||||
}
|
||||
|
||||
static const char path_separator = '/';
|
||||
|
||||
std::string join_paths(const std::string& path1, const std::string& path2) {
|
||||
@@ -293,7 +315,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
|
||||
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
|
||||
}
|
||||
|
||||
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
|
||||
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
||||
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
||||
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
||||
@@ -303,9 +325,13 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
};
|
||||
std::string shader_name = "matmul";
|
||||
|
||||
if (matmul_id) {
|
||||
if (matmul_id_type == MatMulIdType::DEFAULT) {
|
||||
base_dict["MUL_MAT_ID"] = "1";
|
||||
shader_name = "matmul_id";
|
||||
} else if (matmul_id_type == MatMulIdType::SUBGROUP) {
|
||||
base_dict["MUL_MAT_ID"] = "1";
|
||||
base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
|
||||
shader_name = "matmul_id_subgroup";
|
||||
}
|
||||
|
||||
if (fp16) {
|
||||
@@ -313,6 +339,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
}
|
||||
|
||||
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||
if (f16acc) {
|
||||
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
|
||||
}
|
||||
|
||||
if (coopmat) {
|
||||
base_dict["COOPMAT"] = "1";
|
||||
@@ -389,7 +418,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
#endif
|
||||
@@ -401,32 +430,38 @@ void process_shaders() {
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
||||
|
||||
// matmul
|
||||
for (const auto& matmul_id : {false, true}) {
|
||||
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
|
||||
// No coopmats
|
||||
// fp32
|
||||
matmul_shaders(false, matmul_id, false, false, false);
|
||||
matmul_shaders(false, matmul_id_type, false, false, false);
|
||||
|
||||
// fp16, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, false, false, false);
|
||||
matmul_shaders(true, matmul_id, false, false, true);
|
||||
matmul_shaders(true, matmul_id_type, false, false, false);
|
||||
matmul_shaders(true, matmul_id_type, false, false, true);
|
||||
|
||||
if (matmul_id_type != MatMulIdType::DEFAULT) {
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
// Coopmat, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, true, false, false);
|
||||
matmul_shaders(true, matmul_id, true, false, true);
|
||||
// Coopmat, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id_type, true, false, false);
|
||||
matmul_shaders(true, matmul_id_type, true, false, true);
|
||||
#endif
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
// Coopmat2, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id, false, true, false);
|
||||
matmul_shaders(true, matmul_id, false, true, true);
|
||||
// Coopmat2, fp32acc and fp16acc
|
||||
matmul_shaders(true, matmul_id_type, false, true, false);
|
||||
matmul_shaders(true, matmul_id_type, false, true, true);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// flash attention
|
||||
for (const auto& f16acc : {false, true}) {
|
||||
std::string acctype = f16acc ? "float16_t" : "float";
|
||||
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
|
||||
std::map<std::string, std::string> fa_base_dict = base_dict;
|
||||
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
|
||||
if (f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "f32") {
|
||||
@@ -437,30 +472,30 @@ void process_shaders() {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
|
||||
} else {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
}
|
||||
#endif
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -476,8 +511,20 @@ void process_shaders() {
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
|
||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||
|
||||
// mul mat vec with integer dot product
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (is_legacy_quant(tname)) {
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
|
||||
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Dequant shaders
|
||||
if (tname != "f16" && tname != "bf16") {
|
||||
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
|
||||
@@ -503,6 +550,7 @@ void process_shaders() {
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -538,13 +586,15 @@ void process_shaders() {
|
||||
s += std::string(dst_f16 ? "_f16" : "_f32");
|
||||
return s;
|
||||
};
|
||||
for (std::string op : {"add", "sub", "mul", "div"}) {
|
||||
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
|
||||
for (auto src0_f16 : {false, true}) {
|
||||
for (auto src1_f16 : {false, true}) {
|
||||
for (auto dst_f16 : {false, true}) {
|
||||
for (auto rte : {false, true}) {
|
||||
auto source = op == "add_rms" ? std::string("add") : op;
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
||||
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
auto add_rms = op == "add_rms" ? "1" : "0";
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -557,7 +607,12 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
||||
string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
|
||||
|
||||
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
|
||||
string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
|
||||
|
||||
string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
|
||||
string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
|
||||
|
||||
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
@@ -687,7 +742,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
@@ -745,7 +801,7 @@ void write_output_files() {
|
||||
}
|
||||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (const char *op : {"add", "sub", "mul", "div"}) {
|
||||
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
|
||||
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
|
||||
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
|
||||
@@ -798,12 +854,15 @@ void write_output_files() {
|
||||
fputs(len.c_str(), src);
|
||||
}
|
||||
|
||||
for (const std::string& btype : {"f16", "f32"}) {
|
||||
for (const std::string& btype : {"f16", "f32", "q8_1"}) {
|
||||
for (const auto& tname : type_names) {
|
||||
fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[2];\n", tname.c_str(), btype.c_str());
|
||||
fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[2];\n", tname.c_str(), btype.c_str());
|
||||
std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data};\n";
|
||||
std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len};\n";
|
||||
if (btype == "q8_1" && !is_legacy_quant(tname)) {
|
||||
continue;
|
||||
}
|
||||
fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str());
|
||||
fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str());
|
||||
std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n";
|
||||
std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n";
|
||||
fputs(data.c_str(), src);
|
||||
fputs(len.c_str(), src);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user