mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 10:07:44 +02:00
Compare commits
80 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 225e7a1438 | |||
| ab14019821 | |||
| 64978340b0 | |||
| 6ffd4e9c44 | |||
| e4841d24d3 | |||
| 538cc77f7f | |||
| 5cae766541 | |||
| 4b91d6f71f | |||
| cf91f217f1 | |||
| 79e0b68c17 | |||
| c81f4192f9 | |||
| 4a4f426944 | |||
| ba1ceb3456 | |||
| 10a0351a97 | |||
| 68e37a61a7 | |||
| cbc68be51d | |||
| bdca38376f | |||
| 55c509daf5 | |||
| 9c9e4fc635 | |||
| 494c5899cb | |||
| 0f4c6ec0f1 | |||
| 65a3ebb0aa | |||
| 0d9226763c | |||
| 982e347255 | |||
| 923e3ea2e3 | |||
| e743cddb60 | |||
| 05fec5bd29 | |||
| dcf7f2ea3c | |||
| 84b396e051 | |||
| c31e60647d | |||
| 67eade1bf9 | |||
| 7de5c7cab6 | |||
| 8eff95544e | |||
| 3120413ccd | |||
| 215535701d | |||
| 74bb294591 | |||
| 3e303b1107 | |||
| 0c1df14b5f | |||
| b3ad3a0191 | |||
| 98197e5c98 | |||
| f5e96b368f | |||
| 756aa1020a | |||
| aaa088d87f | |||
| 0d5375d54b | |||
| 576c82eda2 | |||
| 0aedae00e6 | |||
| 6bdda13981 | |||
| 0b8855775c | |||
| 4bb625b713 | |||
| 11ee0fea2a | |||
| a457551332 | |||
| 704bb7a71c | |||
| 435a6d10d6 | |||
| f9a867f592 | |||
| ac44eb6c80 | |||
| a57d1bcb3c | |||
| cb9178f885 | |||
| 4a5686da22 | |||
| 98bab638fb | |||
| 26a48ad699 | |||
| ffd59e7d18 | |||
| 105554595f | |||
| 04655063c4 | |||
| 20b7bf8a32 | |||
| 6efcd65945 | |||
| 699f4392a3 | |||
| 08382869a2 | |||
| bb4f7a9e4e | |||
| b8eeb8741d | |||
| 17a1f0d2d4 | |||
| 8f22dc0a53 | |||
| 53903ae6fa | |||
| 4d0dcd4a06 | |||
| 75c91de6e9 | |||
| 68155c66f0 | |||
| e1a7059053 | |||
| 12f55c302b | |||
| b9c3eefde1 | |||
| 6491d6e4f1 | |||
| e592be1575 |
@@ -342,7 +342,7 @@ jobs:
|
||||
cd build
|
||||
export GGML_VK_VISIBLE_DEVICES=0
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 3600
|
||||
ctest -L main --verbose --timeout 4200
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
name: Update Operations Documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
|
||||
|
||||
jobs:
|
||||
update-ops-docs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
- name: Generate operations documentation to temporary file
|
||||
run: |
|
||||
mkdir -p /tmp/ops_check
|
||||
./scripts/create_ops_docs.py /tmp/ops_check/ops.md
|
||||
|
||||
- name: Check if docs/ops.md matches generated version
|
||||
run: |
|
||||
if ! diff -q docs/ops.md /tmp/ops_check/ops.md; then
|
||||
echo "Operations documentation (docs/ops.md) is not up to date with the backend CSV files."
|
||||
echo "To fix: run ./scripts/create_ops_docs.py and commit the updated docs/ops.md along with your changes"
|
||||
echo "Differences found:"
|
||||
diff docs/ops.md /tmp/ops_check/ops.md || true
|
||||
exit 1
|
||||
fi
|
||||
echo "Operations documentation is up to date."
|
||||
@@ -55,6 +55,17 @@
|
||||
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "x64-linux-gcc", "hidden": true,
|
||||
"cacheVariables": {
|
||||
"CMAKE_C_COMPILER": "gcc",
|
||||
"CMAKE_CXX_COMPILER": "g++"
|
||||
}
|
||||
},
|
||||
{ "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] },
|
||||
{ "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] },
|
||||
{ "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] },
|
||||
{ "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] },
|
||||
|
||||
{ "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] },
|
||||
{ "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] },
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
[](https://github.com/ggml-org/llama.cpp/releases)
|
||||
[](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml)
|
||||
|
||||
[Roadmap](https://github.com/users/ggerganov/projects/7) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml)
|
||||
[Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) / [ops](https://github.com/ggml-org/llama.cpp/blob/master/docs/ops.md)
|
||||
|
||||
Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++
|
||||
LLM inference in C/C++
|
||||
|
||||
## Recent API changes
|
||||
|
||||
@@ -17,10 +17,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
||||
|
||||
## Hot topics
|
||||
|
||||
- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
|
||||
- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated
|
||||
- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
|
||||
- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
|
||||
- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
|
||||
- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
|
||||
- Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
|
||||
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669
|
||||
@@ -134,6 +133,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
|
||||
#### Multimodal
|
||||
|
||||
|
||||
@@ -86,8 +86,7 @@ if (LLAMA_CURL)
|
||||
endif()
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
|
||||
include_directories(${CURL_INCLUDE_DIRS})
|
||||
find_library(CURL_LIBRARY curl REQUIRED)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
|
||||
endif ()
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
@@ -112,13 +111,13 @@ if (LLAMA_LLGUIDANCE)
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.7.20 (+ fix to build on GCC 15):
|
||||
GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8
|
||||
# v1.0.1:
|
||||
GIT_TAG d795912fedc7d393de740177ea9ea761e7905774
|
||||
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
||||
SOURCE_DIR ${LLGUIDANCE_SRC}
|
||||
BUILD_IN_SOURCE TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND cargo build --release
|
||||
BUILD_COMMAND cargo build --release --package llguidance
|
||||
INSTALL_COMMAND ""
|
||||
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
|
||||
UPDATE_COMMAND ""
|
||||
|
||||
@@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.swa_full = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_SWA_FULL"));
|
||||
add_opt(common_arg(
|
||||
{"--kv-unified", "-kvu"},
|
||||
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_SPLIT"));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
@@ -2734,6 +2742,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.public_path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
|
||||
add_opt(common_arg(
|
||||
{"--api-prefix"}, "PREFIX",
|
||||
string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.api_prefix = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
|
||||
add_opt(common_arg(
|
||||
{"--no-webui"},
|
||||
string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),
|
||||
@@ -3416,5 +3431,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
// diffusion parameters
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-steps" }, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-eps" }, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-algorithm" }, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
|
||||
params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-alg-temp" }, "F",
|
||||
string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-visual" },
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
|
||||
params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
+13
-6
@@ -1005,15 +1005,21 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias.push_back({i, -INFINITY});
|
||||
}
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
@@ -1157,6 +1163,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
cparams.swa_full = params.swa_full;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
cparams.type_k = params.cache_type_k;
|
||||
cparams.type_v = params.cache_type_v;
|
||||
|
||||
+14
-1
@@ -81,6 +81,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_LOOKUP,
|
||||
LLAMA_EXAMPLE_PARALLEL,
|
||||
LLAMA_EXAMPLE_TTS,
|
||||
LLAMA_EXAMPLE_DIFFUSION,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -177,7 +178,8 @@ struct common_params_sampling {
|
||||
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||
std::set<llama_token> preserved_tokens;
|
||||
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// print the parameters into a string
|
||||
std::string print() const;
|
||||
@@ -217,6 +219,14 @@ struct common_params_vocoder {
|
||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_diffusion {
|
||||
int32_t steps = 64; // number of diffusion steps
|
||||
float eps = 1e-3f; // epsilon for timesteps
|
||||
int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
|
||||
float alg_temp = 0.0f; // algorithm temperature
|
||||
bool visual_mode = false; // show progressive diffusion on screen
|
||||
};
|
||||
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
|
||||
@@ -268,6 +278,7 @@ struct common_params {
|
||||
struct common_params_sampling sampling;
|
||||
struct common_params_speculative speculative;
|
||||
struct common_params_vocoder vocoder;
|
||||
struct common_params_diffusion diffusion;
|
||||
|
||||
struct common_params_model model;
|
||||
|
||||
@@ -330,6 +341,7 @@ struct common_params {
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = true; // context shift on inifinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
@@ -370,6 +382,7 @@ struct common_params {
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
std::string api_prefix = ""; // NOLINT
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
|
||||
+918
-32
File diff suppressed because it is too large
Load Diff
@@ -128,6 +128,9 @@ models = [
|
||||
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
|
||||
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
|
||||
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
|
||||
{"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
|
||||
{"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
|
||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
@@ -137,6 +140,13 @@ pre_computed_hashes = [
|
||||
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
|
||||
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
|
||||
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
|
||||
# falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
|
||||
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
|
||||
]
|
||||
|
||||
|
||||
@@ -222,7 +232,7 @@ for model in models:
|
||||
# generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
|
||||
|
||||
src_ifs = ""
|
||||
for model in [*all_models, *pre_computed_hashes]:
|
||||
for model in [*pre_computed_hashes, *all_models]:
|
||||
name = model["name"]
|
||||
tokt = model["tokt"]
|
||||
chkhsh = model.get("chkhsh")
|
||||
@@ -230,11 +240,6 @@ for model in [*all_models, *pre_computed_hashes]:
|
||||
if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
|
||||
continue
|
||||
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
|
||||
# create the tokenizer
|
||||
if chkhsh is not None:
|
||||
# if the model has a pre-computed hash, use it
|
||||
@@ -244,6 +249,12 @@ for model in [*all_models, *pre_computed_hashes]:
|
||||
chkhsh = existing_models[name]
|
||||
else:
|
||||
# otherwise, compute the hash of the tokenizer
|
||||
|
||||
# Skip if the tokenizer folder does not exist or there are other download issues previously
|
||||
if not os.path.exists(f"models/tokenizers/{name}"):
|
||||
logger.warning(f"Directory for tokenizer {name} not found. Skipping...")
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...")
|
||||
if name == "t5":
|
||||
|
||||
@@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv
|
||||
|
||||
### 2. Define the model architecture in `llama.cpp`
|
||||
|
||||
The model params and tensors layout must be defined in `llama.cpp`:
|
||||
1. Define a new `llm_arch`
|
||||
2. Define the tensors layout in `LLM_TENSOR_NAMES`
|
||||
3. Add any non-standard metadata in `llm_load_hparams`
|
||||
4. Create the tensors for inference in `llm_load_tensors`
|
||||
5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
|
||||
The model params and tensors layout must be defined in `llama.cpp` source files:
|
||||
1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
|
||||
2. In `src/llama-arch.cpp`:
|
||||
- Add the architecture name to the `LLM_ARCH_NAMES` map.
|
||||
- Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
|
||||
3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
|
||||
4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.
|
||||
|
||||
NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
|
||||
|
||||
### 3. Build the GGML graph implementation
|
||||
|
||||
This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
|
||||
|
||||
Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
|
||||
This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
|
||||
Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
|
||||
Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
|
||||
Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.
|
||||
|
||||
Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
|
||||
|
||||
|
||||
+95
@@ -0,0 +1,95 @@
|
||||
# GGML Operations
|
||||
|
||||
List of GGML operations and backend support status.
|
||||
|
||||
Legend:
|
||||
- ✅ Fully supported by this backend
|
||||
- 🟡 Partially supported by this backend
|
||||
- ❌ Not supported by this backend
|
||||
|
||||
| Operation | BLAS | CPU | CUDA | Metal |
|
||||
|-----------|------|------|------|------|
|
||||
| ABS | ❌ | ✅ | 🟡 | ❌ |
|
||||
| ACC | ❌ | ✅ | ✅ | ✅ |
|
||||
| ADD | ❌ | ✅ | ✅ | 🟡 |
|
||||
| ADD1 | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARANGE | ❌ | ✅ | ✅ | ✅ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | 🟡 |
|
||||
| CONCAT | ❌ | ✅ | 🟡 | ✅ |
|
||||
| CONT | ❌ | ✅ | 🟡 | ✅ |
|
||||
| CONV_2D_DW | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ✅ | ✅ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | 🟡 |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | 🟡 |
|
||||
| DIV | ❌ | ✅ | ✅ | 🟡 |
|
||||
| DUP | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| ELU | ❌ | ✅ | ❌ | 🟡 |
|
||||
| EXP | ❌ | ✅ | 🟡 | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | 🟡 |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | 🟡 |
|
||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | 🟡 |
|
||||
| GELU | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GELU_ERF | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GELU_QUICK | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| GET_ROWS | ❌ | ✅ | 🟡 | ✅ |
|
||||
| GET_ROWS_BACK | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ |
|
||||
| HARDSIGMOID | ❌ | ✅ | 🟡 | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | 🟡 | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | 🟡 |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ |
|
||||
| LOG | ❌ | ✅ | ✅ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ |
|
||||
| MUL | ❌ | ✅ | ✅ | 🟡 |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | ✅ | ✅ | ✅ |
|
||||
| NEG | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| NORM | ❌ | ✅ | ✅ | 🟡 |
|
||||
| OPT_STEP_ADAMW | ❌ | ✅ | ✅ | ❌ |
|
||||
| OUT_PROD | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ❌ | ✅ |
|
||||
| POOL_2D | ❌ | ✅ | ✅ | ✅ |
|
||||
| REGLU | ❌ | ✅ | ✅ | 🟡 |
|
||||
| RELU | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| REPEAT | ❌ | ✅ | 🟡 | ✅ |
|
||||
| REPEAT_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM | ❌ | ✅ | ✅ | 🟡 |
|
||||
| RMS_NORM_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| RMS_NORM_MUL | ❌ | ✅ | ✅ | ✅ |
|
||||
| ROPE | ❌ | ✅ | ✅ | ✅ |
|
||||
| ROPE_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ✅ | ✅ | ✅ |
|
||||
| RWKV_WKV7 | ❌ | ✅ | ✅ | ✅ |
|
||||
| SCALE | ❌ | ✅ | ✅ | ✅ |
|
||||
| SET | ❌ | ✅ | ❌ | ✅ |
|
||||
| SET_ROWS | ❌ | 🟡 | ❌ | 🟡 |
|
||||
| SGN | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| SILU | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| SILU_BACK | ❌ | ✅ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SOFT_MAX | ❌ | ✅ | ✅ | ✅ |
|
||||
| SOFT_MAX_BACK | ❌ | 🟡 | 🟡 | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SQRT | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ |
|
||||
| SSM_SCAN | ❌ | ✅ | ✅ | ✅ |
|
||||
| STEP | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | 🟡 |
|
||||
| SUM | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | 🟡 |
|
||||
| TANH | ❌ | ✅ | 🟡 | 🟡 |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ |
|
||||
| UPSCALE | ❌ | ✅ | ✅ | 🟡 |
|
||||
+6534
File diff suppressed because it is too large
Load Diff
+6534
File diff suppressed because it is too large
Load Diff
+6534
File diff suppressed because it is too large
Load Diff
+6534
File diff suppressed because it is too large
Load Diff
@@ -33,6 +33,7 @@ else()
|
||||
add_subdirectory(speculative-simple)
|
||||
add_subdirectory(gen-docs)
|
||||
add_subdirectory(training)
|
||||
add_subdirectory(diffusion)
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
add_subdirectory(convert-llama2c-to-ggml)
|
||||
# these examples use the backends directly and cannot be built with dynamic loading
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-diffusion-cli)
|
||||
add_executable(${TARGET} diffusion-cli.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
@@ -0,0 +1,507 @@
|
||||
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <limits.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
enum diffusion_alg {
|
||||
DIFFUSION_ALG_ORIGIN = 0,
|
||||
DIFFUSION_ALG_MASKGIT_PLUS = 1,
|
||||
DIFFUSION_ALG_TOPK_MARGIN = 2,
|
||||
DIFFUSION_ALG_ENTROPY = 3,
|
||||
};
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps;
|
||||
float eps;
|
||||
float temperature;
|
||||
float top_p;
|
||||
int32_t top_k;
|
||||
llama_token mask_token_id;
|
||||
enum diffusion_alg algorithm;
|
||||
float alg_temp;
|
||||
diffusion_step_callback_t step_callback;
|
||||
void * step_callback_user_data;
|
||||
int32_t seed;
|
||||
};
|
||||
|
||||
|
||||
static diffusion_params diffusion_default_params() {
|
||||
diffusion_params params = {};
|
||||
params.steps = 64;
|
||||
params.eps = 1e-3f;
|
||||
params.temperature = 0.2f;
|
||||
params.top_p = 0.95f;
|
||||
params.top_k = 0;
|
||||
params.mask_token_id = LLAMA_TOKEN_NULL;
|
||||
params.algorithm = DIFFUSION_ALG_ORIGIN;
|
||||
params.alg_temp = 0.0f;
|
||||
params.step_callback = nullptr;
|
||||
params.step_callback_user_data = nullptr;
|
||||
params.seed = 0;
|
||||
return params;
|
||||
}
|
||||
|
||||
static void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
int32_t max_length,
|
||||
struct diffusion_params params,
|
||||
int32_t & n_generated) {
|
||||
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
std::vector<float> timesteps(params.steps + 1);
|
||||
for (int32_t i = 0; i <= params.steps; i++) {
|
||||
timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
|
||||
}
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(max_length);
|
||||
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(max_length);
|
||||
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(max_length, 0, 1);
|
||||
batch.n_tokens = max_length;
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
|
||||
int64_t time_start = ggml_time_us();
|
||||
for (int32_t step = 0; step < params.steps; step++) {
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
float * raw_logits = llama_get_logits(ctx);
|
||||
if (!raw_logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
float t = timesteps[step];
|
||||
float s = timesteps[step + 1];
|
||||
|
||||
if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
|
||||
float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
/* .data = */ candidates.data(),
|
||||
/* .size = */ (size_t) n_vocab, // Reset size to full vocab
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
/* .data = */ candidates.data(),
|
||||
/* .size = */ candidates.size(),
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float confidence = 0.0f;
|
||||
if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t j = 0; j < cur_p.size; j++) {
|
||||
float prob = cur_p.data[j].p;
|
||||
confidence += prob * logf(prob + epsilon);
|
||||
}
|
||||
} else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
|
||||
confidence = cur_p.data[0].p - cur_p.data[1].p;
|
||||
} else {
|
||||
confidence = cur_p.data[cur_p.selected].p;
|
||||
}
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(confidence, i);
|
||||
}
|
||||
|
||||
int32_t num_transfer =
|
||||
(step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
|
||||
|
||||
if (num_transfer > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
|
||||
for (int32_t pos = 0; pos < max_length; pos++) {
|
||||
float conf_logit = -std::numeric_limits<float>::infinity();
|
||||
|
||||
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
||||
if (it != mask_positions.end()) {
|
||||
size_t mask_idx = std::distance(mask_positions.begin(), it);
|
||||
conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
|
||||
}
|
||||
|
||||
conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
/* .data = */ conf_candidates.data(),
|
||||
/* .size = */ conf_candidates.size(),
|
||||
/* .selected = */ -1,
|
||||
/* .sorted = */ false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
// Apply distribution sampler to get selected index
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int selected_idx = conf_array.selected;
|
||||
confidences[i].second = conf_candidates[selected_idx].id;
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.alg_temp == 0.0f) {
|
||||
// Deterministic - use confidence order
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
llama_token token = sampled_tokens[mask_idx];
|
||||
output_tokens[pos] = token;
|
||||
}
|
||||
} else {
|
||||
for (int32_t i = 0; i < num_transfer; i++) {
|
||||
int32_t pos = confidences[i].second;
|
||||
auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
|
||||
if (it != mask_positions.end()) {
|
||||
int32_t mask_idx = std::distance(mask_positions.begin(), it);
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = max_length;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
|
||||
if (!use_chat_template) {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
auto chat_templates = common_chat_templates_init(model, "");
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
common_chat_msg user_msg;
|
||||
user_msg.role = "user";
|
||||
user_msg.content = prompt;
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.messages.push_back(user_msg);
|
||||
|
||||
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
|
||||
|
||||
return result.prompt;
|
||||
}
|
||||
|
||||
struct callback_data {
|
||||
const common_params_diffusion * diff_params;
|
||||
const llama_vocab * vocab;
|
||||
int32_t n_input;
|
||||
};
|
||||
|
||||
static bool diffusion_step_callback(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data) {
|
||||
(void)user_data;
|
||||
|
||||
callback_data * data = static_cast<callback_data *>(user_data);
|
||||
|
||||
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
|
||||
int progress_percent = (step * 100) / total_steps;
|
||||
int progress_bars = (step * 50) / total_steps;
|
||||
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
|
||||
step,
|
||||
total_steps,
|
||||
std::string(progress_bars, '=').c_str(),
|
||||
std::string(50 - progress_bars, ' ').c_str(),
|
||||
progress_percent);
|
||||
};
|
||||
|
||||
if (data->diff_params->visual_mode) {
|
||||
// Visual mode: clear
|
||||
LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
|
||||
|
||||
print_progress_bar(step, total_steps);
|
||||
|
||||
LOG_INF("\n");
|
||||
|
||||
std::string current_text = " ";
|
||||
|
||||
for (int32_t i = data->n_input; i < n_tokens; i++) {
|
||||
std::string token_str;
|
||||
if (tokens[i] != llama_vocab_mask(data->vocab)) {
|
||||
char piece[256];
|
||||
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
|
||||
if (n_chars > 0) {
|
||||
piece[n_chars] = '\0';
|
||||
token_str = piece;
|
||||
}
|
||||
} else {
|
||||
token_str = " ";
|
||||
}
|
||||
|
||||
current_text += token_str;
|
||||
}
|
||||
|
||||
LOG_INF("%s\n", current_text.c_str());
|
||||
} else {
|
||||
print_progress_bar(step, total_steps);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
|
||||
const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
|
||||
alg_names[params.diffusion.algorithm] :
|
||||
"UNKNOWN";
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.n_gpu_layers = params.n_gpu_layers;
|
||||
model_params.devices = params.devices.data();
|
||||
model_params.use_mmap = params.use_mmap;
|
||||
model_params.use_mlock = params.use_mlock;
|
||||
model_params.check_tensors = params.check_tensors;
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
|
||||
if (!model) {
|
||||
LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = params.n_ctx;
|
||||
ctx_params.n_batch = params.n_batch;
|
||||
ctx_params.n_ubatch = params.n_ubatch;
|
||||
ctx_params.flash_attn = params.flash_attn;
|
||||
ctx_params.no_perf = params.no_perf;
|
||||
ctx_params.type_k = params.cache_type_k;
|
||||
ctx_params.type_v = params.cache_type_v;
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
if (!ctx) {
|
||||
LOG_ERR("error: failed to create context\n");
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
|
||||
|
||||
std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
|
||||
/*add special tokens*/ true,
|
||||
/*parse special*/ true);
|
||||
int n_input = input_tokens.size();
|
||||
|
||||
if (n_input >= params.n_ctx) {
|
||||
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
struct diffusion_params ldiff_params = diffusion_default_params();
|
||||
ldiff_params.steps = params.diffusion.steps;
|
||||
ldiff_params.eps = params.diffusion.eps;
|
||||
ldiff_params.temperature = params.sampling.temp;
|
||||
ldiff_params.top_p = params.sampling.top_p;
|
||||
ldiff_params.top_k = params.sampling.top_k;
|
||||
ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
|
||||
ldiff_params.alg_temp = params.diffusion.alg_temp;
|
||||
ldiff_params.seed = params.sampling.seed;
|
||||
|
||||
llama_token mask_token_id = llama_vocab_mask(vocab);
|
||||
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
|
||||
|
||||
LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
|
||||
alg_name);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
|
||||
|
||||
ldiff_params.mask_token_id = mask_token_id;
|
||||
|
||||
callback_data cb_data = { ¶ms.diffusion, vocab, n_input };
|
||||
|
||||
ldiff_params.step_callback = diffusion_step_callback;
|
||||
ldiff_params.step_callback_user_data = &cb_data;
|
||||
|
||||
int32_t n_generated = 0;
|
||||
|
||||
std::vector<llama_token> output_tokens(params.n_ubatch);
|
||||
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
|
||||
ldiff_params, n_generated);
|
||||
|
||||
if (n_generated > 0) {
|
||||
if (params.diffusion.visual_mode) {
|
||||
//clear screen and move cursor to top-left
|
||||
LOG_INF("\033[2J\033[H");
|
||||
}
|
||||
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
|
||||
std::string output_data = common_detokenize(vocab, output_tokens, false);
|
||||
LOG_INF("\n%s\n", output_data.c_str());
|
||||
} else {
|
||||
LOG_INF("Error: diffusion generation failed\n");
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ int main(int argc, char ** argv) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.smpl = common_sampler_init(model, params.sampling);
|
||||
//params.sampling.seed++;
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
@@ -345,7 +346,7 @@ int main(int argc, char ** argv) {
|
||||
client.n_decoded = 0;
|
||||
client.i_batch = batch.n_tokens - 1;
|
||||
|
||||
LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur);
|
||||
LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);
|
||||
|
||||
g_seq_id += 1;
|
||||
|
||||
|
||||
+14
-1
@@ -495,7 +495,7 @@ extern "C" {
|
||||
GGML_OP_POOL_1D,
|
||||
GGML_OP_POOL_2D,
|
||||
GGML_OP_POOL_2D_BACK,
|
||||
GGML_OP_UPSCALE, // nearest interpolate
|
||||
GGML_OP_UPSCALE,
|
||||
GGML_OP_PAD,
|
||||
GGML_OP_PAD_REFLECT_1D,
|
||||
GGML_OP_ROLL,
|
||||
@@ -1297,6 +1297,19 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
float s);
|
||||
|
||||
// x = s * a + b
|
||||
GGML_API struct ggml_tensor * ggml_scale_bias(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s,
|
||||
float b);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s,
|
||||
float b);
|
||||
|
||||
// b -> view(a,offset,nb1,nb2,3), return modified a
|
||||
GGML_API struct ggml_tensor * ggml_set(
|
||||
struct ggml_context * ctx,
|
||||
|
||||
@@ -2090,6 +2090,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
#pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_CPY: {
|
||||
@@ -2188,7 +2189,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_CLAMP:
|
||||
@@ -2210,6 +2210,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return true;
|
||||
case GGML_OP_SCALE:
|
||||
float bias;
|
||||
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
|
||||
return bias == 0.0f; // TODO: support bias != 0.0f
|
||||
case GGML_OP_SOFT_MAX:
|
||||
// TODO: support broadcast
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
|
||||
|
||||
+340
-1091
File diff suppressed because it is too large
Load Diff
@@ -4015,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32(
|
||||
|
||||
const float scale = 1.0f/sqrtf(mean + eps);
|
||||
|
||||
// if you hit this, likely you got an inf somewhere earlier
|
||||
assert(scale > 0.0f);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
@@ -4643,9 +4646,11 @@ static void ggml_compute_forward_scale_f32(
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
// scale factor
|
||||
float v;
|
||||
memcpy(&v, dst->op_params, sizeof(float));
|
||||
float s; // scale factor
|
||||
float b; // bias
|
||||
|
||||
memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@@ -4664,12 +4669,22 @@ static void ggml_compute_forward_scale_f32(
|
||||
|
||||
const size_t nb1 = dst->nb[1];
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
if (dst->data != src0->data) {
|
||||
// src0 is same shape as dst => same indices
|
||||
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
|
||||
if (b == 0.0f) {
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
if (dst->data != src0->data) {
|
||||
// src0 is same shape as dst => same indices
|
||||
// TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
|
||||
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
|
||||
}
|
||||
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
|
||||
}
|
||||
} else {
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
ggml_vec_mad1_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*nb1),
|
||||
(float *) ((char *) src0->data + i1*nb1),
|
||||
s, b);
|
||||
}
|
||||
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
|
||||
for (int i = np; i < n; ++i) {
|
||||
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
|
||||
// if you hit this, you are likely running outside the FP range
|
||||
assert(!isnan(sumf) && !isinf(sumf));
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
|
||||
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
|
||||
#if defined(GGML_USE_ACCELERATE)
|
||||
vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
|
||||
#elif defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
// scalar ; TODO: Write SVE code
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = x[i]*s + b;
|
||||
}
|
||||
#else
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
||||
GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
|
||||
GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);
|
||||
|
||||
GGML_F32_VEC ay[GGML_F32_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
||||
for (int j = 0; j < GGML_F32_ARR; j++) {
|
||||
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
|
||||
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = x[i]*s + b;
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = x[i]*s + b;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
||||
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
#if defined(GGML_USE_ACCELERATE)
|
||||
|
||||
@@ -176,17 +176,20 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#endif
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
|
||||
const int id = ggml_cuda_get_device(); \
|
||||
if (!shared_memory_limit_raised[id]) { \
|
||||
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
||||
shared_memory_limit_raised[id] = true; \
|
||||
} \
|
||||
} while (0)
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
|
||||
const int id = ggml_cuda_get_device(); \
|
||||
if (!shared_memory_limit_raised[id]) { \
|
||||
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
||||
shared_memory_limit_raised[id] = true; \
|
||||
} \
|
||||
} while (0)
|
||||
#else
|
||||
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
GGML_UNUSED(nbytes); \
|
||||
} while (0)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
|
||||
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||
|
||||
@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst,
|
||||
const int parallel_blocks) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
|
||||
dst += D * gridDim.z*blockIdx.x;
|
||||
// Dimension 0: threadIdx.x
|
||||
// Dimension 1: blockIdx.x
|
||||
// Dimension 2: blockIdx.y
|
||||
// Dimension 3: blockIdx.z
|
||||
// Memory layout is permuted with [0, 2, 1, 3]
|
||||
|
||||
const int ne01 = gridDim.x;
|
||||
const int ne02 = gridDim.y;
|
||||
|
||||
const int col = blockIdx.x;
|
||||
const int head = blockIdx.y;
|
||||
const int sequence = blockIdx.z;
|
||||
|
||||
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
|
||||
|
||||
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
|
||||
VKQ_meta += j_dst_unrolled * parallel_blocks;
|
||||
dst += j_dst_unrolled * D;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
extern __shared__ float2 meta[];
|
||||
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
||||
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
|
||||
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
dst[tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
@@ -705,8 +723,6 @@ void launch_fattn(
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
GGML_ASSERT(Q->ne[3] == 1);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
@@ -853,8 +869,8 @@ void launch_fattn(
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
|
||||
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
||||
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
@@ -869,11 +885,11 @@ void launch_fattn(
|
||||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<DV>
|
||||
|
||||
@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
|
||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
int kb0_start = kbc % iter_k;
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
|
||||
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
||||
kqsum_j = warp_reduce_sum((float)kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= __half2half2(kqsum_j);
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
||||
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -299,14 +305,14 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio);
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
||||
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
@@ -326,24 +330,25 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
||||
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const half2 * mask2 = (const half2 *) maskh;
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
|
||||
@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
|
||||
float KQ_rowsum_j;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
|
||||
dst[j_dst_unrolled*D + i] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y == 1 || threadIdx.x != 0) {
|
||||
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||
}
|
||||
dst_meta_val.y = KQ_rowsum_j;
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
|
||||
dst_meta[j_dst_unrolled] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
|
||||
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
|
||||
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
||||
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_BF16:
|
||||
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
@@ -210,6 +214,10 @@ void get_rows_cuda(
|
||||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
ggml_cuda_op_get_rows_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_cuda_op_set_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
@@ -2299,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_EXP:
|
||||
ggml_cuda_op_exp(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_ELU:
|
||||
ggml_cuda_op_elu(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3112,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
@@ -3200,6 +3208,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -3214,6 +3224,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
{
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
|
||||
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
op->src[1]->type == GGML_TYPE_I64;
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
@@ -3333,8 +3350,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SSM_SCAN: {
|
||||
if (op->src[3]->ne[0] == 1) {
|
||||
// Mamba2
|
||||
// (kernel only supports d_state == 128 && d_head % 16 == 0)
|
||||
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
|
||||
// (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
|
||||
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
|
||||
} else {
|
||||
// Mamba
|
||||
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
|
||||
@@ -3373,7 +3390,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
@@ -3397,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
|
||||
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
@@ -3415,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
if (op->src[3] && op->src[3]->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
}
|
||||
|
||||
+21
-27
@@ -50,21 +50,19 @@ static __global__ void rope_norm(
|
||||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + 0] = x[ix + 0];
|
||||
dst[idst + 1] = x[ix + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
@@ -94,21 +92,19 @@ static __global__ void rope_neox(
|
||||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
@@ -138,21 +134,19 @@ static __global__ void rope_multi(
|
||||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
#include "scale.cuh"
|
||||
|
||||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
|
||||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[i] = scale * x[i];
|
||||
dst[i] = scale * x[i] + bias;
|
||||
}
|
||||
|
||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
||||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
||||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
float bias;
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
|
||||
scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
#include "set-rows.cuh"
|
||||
|
||||
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
|
||||
GGML_UNUSED(src_f);
|
||||
GGML_UNUSED(dst_f);
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
|
||||
*dst_h = __float2half(*src_f);
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
|
||||
*dst_b = *src_f;
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
|
||||
*dst_f = *src_f;
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static __global__ void k_set_rows(
|
||||
const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s10, const int64_t s11, const int64_t s12,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
||||
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
|
||||
|
||||
if (i >= ne_total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i03 = i / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
|
||||
const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
|
||||
dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
|
||||
|
||||
const src_t* src_elem = src0_row + i00;
|
||||
dst_t* dst_elem = dst_row_ptr + i00;
|
||||
set_rows_1(src_elem, dst_elem);
|
||||
|
||||
GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne13);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void set_rows_cuda(
|
||||
const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
|
||||
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
|
||||
const dim3 grid_size(num_blocks);
|
||||
|
||||
|
||||
const int64_t s01 = nb01/sizeof(src_t);
|
||||
const int64_t s02 = nb02/sizeof(src_t);
|
||||
const int64_t s03 = nb03/sizeof(src_t);
|
||||
const int64_t s10 = nb10/sizeof(int64_t);
|
||||
const int64_t s11 = nb11/sizeof(int64_t);
|
||||
const int64_t s12 = nb12/sizeof(int64_t);
|
||||
const int64_t s1 = nb1/sizeof(dst_t);
|
||||
const int64_t s2 = nb2/sizeof(dst_t);
|
||||
const int64_t s3 = nb3/sizeof(dst_t);
|
||||
|
||||
if (ne_total > 0) {
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
s01, s02, s03,
|
||||
s10, s11, s12,
|
||||
s1, s2, s3);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I64);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const int64_t * src1_d = (const int64_t *)src1->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
set_rows_cuda(
|
||||
src0_d, src1_d, (float*)dst->data,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb01, nb02, nb03,
|
||||
nb10, nb11, nb12,
|
||||
nb1, nb2, nb3,
|
||||
stream
|
||||
);
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
set_rows_cuda(
|
||||
src0_d, src1_d, (half*)dst->data,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb01, nb02, nb03,
|
||||
nb10, nb11, nb12,
|
||||
nb1, nb2, nb3,
|
||||
stream
|
||||
);
|
||||
} else if (dst->type == GGML_TYPE_BF16) {
|
||||
set_rows_cuda(
|
||||
src0_d, src1_d, (nv_bfloat16*)dst->data,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb01, nb02, nb03,
|
||||
nb10, nb11, nb12,
|
||||
nb1, nb2, nb3,
|
||||
stream
|
||||
);
|
||||
} else {
|
||||
GGML_ABORT("unsupported type");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_SET_ROWS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
||||
if (nc == 4) {
|
||||
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else if (nc == 3) {
|
||||
ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else {
|
||||
GGML_ABORT("Only support kernel size = 4 now.");
|
||||
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
|
||||
}
|
||||
} else {
|
||||
if (nc == 4) {
|
||||
@@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
|
||||
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else if (nc == 3) {
|
||||
const int64_t split_n_t = 32;
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
|
||||
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else {
|
||||
GGML_ABORT("Only support kernel size = 4 right now.");
|
||||
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
|
||||
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
|
||||
cudaStream_t stream) {
|
||||
const int threads = 128;
|
||||
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
|
||||
if (src3_nb1 == sizeof(float)) {
|
||||
// Mamba-2
|
||||
if (d_state == 128) {
|
||||
const int threads = 128;
|
||||
GGML_ASSERT(d_state % threads == 0);
|
||||
// NOTE: can be any power of two between 4 and 64
|
||||
const int splitH = 16;
|
||||
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
||||
} else if (d_state == 256) { // Falcon-H1
|
||||
const int threads = 256;
|
||||
// NOTE: can be any power of two between 8 and 64
|
||||
const int splitH = 16;
|
||||
GGML_ASSERT(head_dim % splitH == 0);
|
||||
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
|
||||
ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
|
||||
src0, src1, src2, src3, src4, src5, src6, dst,
|
||||
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
||||
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
||||
} else {
|
||||
GGML_ABORT("doesn't support d_state!=128.");
|
||||
GGML_ABORT("doesn't support d_state!=(128 or 256).");
|
||||
}
|
||||
} else {
|
||||
const int threads = 128;
|
||||
// Mamba-1
|
||||
GGML_ASSERT(n_head % threads == 0);
|
||||
GGML_ASSERT(head_dim == 1);
|
||||
|
||||
@@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_elu(float x) {
|
||||
return (x > 0.f) ? x : expm1f(x);
|
||||
}
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
@@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_log>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_elu>(ctx, dst);
|
||||
}
|
||||
/* gated ops */
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
|
||||
@@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
|
||||
dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
|
||||
}
|
||||
|
||||
static __global__ void upscale_f32_bilinear(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset) {
|
||||
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
|
||||
if (index >= dst_total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i10_dst = index % ne10_dst;
|
||||
const int i11_dst = (index / ne10_dst) % ne11_dst;
|
||||
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
|
||||
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
|
||||
|
||||
const int i02_src = (int)(i12_dst / sf2);
|
||||
const int i03_src = (int)(i13_dst / sf3);
|
||||
|
||||
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
|
||||
int y0_src = (int)floorf(y_src_f);
|
||||
int y1_src = y0_src + 1;
|
||||
|
||||
y0_src = max(0, min(y0_src, ne01_src - 1));
|
||||
y1_src = max(0, min(y1_src, ne01_src - 1));
|
||||
|
||||
float dy = y_src_f - (float)y0_src;
|
||||
dy = max(0.0f, min(dy, 1.0f));
|
||||
|
||||
float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
|
||||
int x0_src = (int)floorf(x_src_f);
|
||||
int x1_src = x0_src + 1;
|
||||
|
||||
x0_src = max(0, min(x0_src, ne00_src - 1));
|
||||
x1_src = max(0, min(x1_src, ne00_src - 1));
|
||||
|
||||
float dx = x_src_f - (float)x0_src;
|
||||
dx = max(0.0f, min(dx, 1.0f));
|
||||
|
||||
const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
|
||||
|
||||
const float val_a = *p_a;
|
||||
const float val_b = *p_b;
|
||||
const float val_c = *p_c;
|
||||
const float val_d = *p_d;
|
||||
|
||||
float result = val_a * (1.0f - dx) * (1.0f - dy) +
|
||||
val_b * dx * (1.0f - dy) +
|
||||
val_c * (1.0f - dx) * dy +
|
||||
val_d * dx * dy;
|
||||
|
||||
dst[index] = result;
|
||||
}
|
||||
|
||||
static void upscale_f32_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne10, const int ne11, const int ne12, const int ne13,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
cudaStream_t stream) {
|
||||
int dst_size = ne10 * ne11 * ne12 * ne13;
|
||||
int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
|
||||
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
|
||||
}
|
||||
|
||||
static void upscale_f32_bilinear_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset, cudaStream_t stream) {
|
||||
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
@@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
const float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
const float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const int mode_flags = dst->op_params[0];
|
||||
const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
|
||||
|
||||
float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
||||
|
||||
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
if (mode == GGML_SCALE_MODE_NEAREST) {
|
||||
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, stream);
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+14
-5
@@ -10,9 +10,6 @@
|
||||
#include "rocblas/rocblas.h"
|
||||
#endif // __HIP_PLATFORM_AMD__
|
||||
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||
@@ -30,7 +27,6 @@
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
@@ -42,7 +38,6 @@
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
@@ -144,6 +139,20 @@
|
||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||
|
||||
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
|
||||
#define cublasComputeType_t hipblasComputeType_t
|
||||
#define cudaDataType_t hipDataType
|
||||
#else
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||
#define cublasComputeType_t hipblasDatatype_t
|
||||
#define cudaDataType_t hipblasDatatype_t
|
||||
#endif
|
||||
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
||||
|
||||
@@ -173,6 +173,12 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SILU,
|
||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||
GGML_METAL_KERNEL_TYPE_ELU,
|
||||
GGML_METAL_KERNEL_TYPE_ABS,
|
||||
GGML_METAL_KERNEL_TYPE_SGN,
|
||||
GGML_METAL_KERNEL_TYPE_STEP,
|
||||
GGML_METAL_KERNEL_TYPE_HARDSWISH,
|
||||
GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
|
||||
GGML_METAL_KERNEL_TYPE_EXP,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||
@@ -1155,6 +1161,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
||||
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_STEP:
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
@@ -2256,7 +2274,9 @@ static bool ggml_metal_encode_node(
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(scale));
|
||||
float bias;
|
||||
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
||||
memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
||||
|
||||
int64_t n = ggml_nelements(dst);
|
||||
|
||||
@@ -2273,6 +2293,7 @@ static bool ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@@ -2436,6 +2457,78 @@ static bool ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_ABS:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_STEP:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXP:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
|
||||
@@ -1014,16 +1014,18 @@ kernel void kernel_scale(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant float & scale,
|
||||
constant float & bias,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * scale;
|
||||
dst[tpig] = src0[tpig] * scale + bias;
|
||||
}
|
||||
|
||||
kernel void kernel_scale_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
constant float & scale,
|
||||
constant float & bias,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * scale;
|
||||
dst[tpig] = src0[tpig] * scale + bias;
|
||||
}
|
||||
|
||||
kernel void kernel_clamp(
|
||||
@@ -1197,6 +1199,51 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_abs(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = fabs(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sgn(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_step(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_hardswish(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_hardsigmoid(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_exp(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_reglu(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
||||
@@ -88,6 +88,7 @@ set(GGML_OPENCL_KERNELS
|
||||
rms_norm
|
||||
rope
|
||||
scale
|
||||
set_rows
|
||||
sigmoid
|
||||
silu
|
||||
softmax_4_f32
|
||||
@@ -103,6 +104,7 @@ set(GGML_OPENCL_KERNELS
|
||||
tanh
|
||||
pad
|
||||
repeat
|
||||
mul_mat_f16_f32
|
||||
)
|
||||
|
||||
foreach (K ${GGML_OPENCL_KERNELS})
|
||||
|
||||
@@ -351,6 +351,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_program program_gemv_noshuffle_general;
|
||||
cl_program program_gemv_noshuffle;
|
||||
cl_program program_get_rows;
|
||||
cl_program program_set_rows;
|
||||
cl_program program_glu;
|
||||
cl_program program_im2col_f16;
|
||||
cl_program program_im2col_f32;
|
||||
@@ -367,6 +368,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_program program_mul_mv_f16_f32;
|
||||
cl_program program_mul_mv_f32_f32;
|
||||
cl_program program_mul;
|
||||
cl_program program_mul_mat_f16_f32_tiled;
|
||||
cl_program program_div;
|
||||
cl_program program_sub;
|
||||
cl_program program_norm;
|
||||
@@ -412,6 +414,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
||||
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
|
||||
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
|
||||
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
|
||||
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
|
||||
cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
|
||||
cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
|
||||
@@ -420,6 +423,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mat_f16_f32_1row;
|
||||
cl_kernel kernel_mul_mat_f16_f32;
|
||||
cl_kernel kernel_mul_mat_f16_f32_l4;
|
||||
cl_kernel kernel_mul_mat_f16_f32_tiled;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
@@ -529,6 +533,16 @@ struct ggml_backend_opencl_context {
|
||||
fclose(ftrace);
|
||||
}
|
||||
|
||||
size_t get_kernel_workgroup_size(cl_kernel kernel) const {
|
||||
size_t workgroup_size = 0;
|
||||
size_t ret_size = 0;
|
||||
CL_CHECK(
|
||||
clGetKernelWorkGroupInfo(kernel, device, CL_KERNEL_WORK_GROUP_SIZE,
|
||||
sizeof(size_t), &workgroup_size, &ret_size));
|
||||
GGML_ASSERT(sizeof(size_t) == ret_size);
|
||||
return workgroup_size;
|
||||
}
|
||||
|
||||
void enqueue_ndrange_kernel(cl_kernel kernel, cl_uint work_dim, size_t *global_work_size, size_t *local_work_size, const ggml_tensor * tensor) {
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
cl_event evt;
|
||||
@@ -1003,6 +1017,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mat_f16_f32_tiled
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mat_f16_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mat_f16_f32.cl");
|
||||
#endif
|
||||
backend_ctx->program_mul_mat_f16_f32_tiled =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -1431,6 +1461,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
}
|
||||
}
|
||||
|
||||
// set_rows
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "set_rows.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("set_rows.cl");
|
||||
#endif
|
||||
backend_ctx->program_set_rows =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_set_rows_f32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_set_rows_f16 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_id_q4_0_f32_8x_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -2233,8 +2280,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
return false;
|
||||
} break;
|
||||
#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
|
||||
if (op->src[0]->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
@@ -3374,6 +3431,111 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(src1);
|
||||
GGML_ASSERT(src1->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
// ne0 = ne00
|
||||
// ne2 = ne02
|
||||
// ne3 = ne03
|
||||
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne02 = src0->ne[2];
|
||||
const int ne03 = src0->ne[3];
|
||||
|
||||
const cl_ulong nb01 = src0->nb[1];
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
const cl_ulong nb03 = src0->nb[3];
|
||||
|
||||
const int ne11 = src1->ne[1];
|
||||
const int ne12 = src1->ne[2];
|
||||
|
||||
const cl_ulong nb10 = src1->nb[0];
|
||||
const cl_ulong nb11 = src1->nb[1];
|
||||
const cl_ulong nb12 = src1->nb[2];
|
||||
|
||||
const int ne0 = dst->ne[0];
|
||||
|
||||
const cl_ulong nb1 = dst->nb[1];
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
const int nblk0 = ne0/ggml_blck_size(dst->type);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_set_rows_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_set_rows_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("not implemented");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &nblk0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3));
|
||||
|
||||
int nth0 = 64;
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 32;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
}
|
||||
|
||||
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
|
||||
while (nth0 < nblk0 && nth0 < max_workgroup_size) {
|
||||
nth0 *= 2;
|
||||
}
|
||||
|
||||
int rows_per_workgroup = 1;
|
||||
if (nth0 > nblk0) {
|
||||
rows_per_workgroup = nth0 / nblk0;
|
||||
nth0 = nblk0;
|
||||
}
|
||||
|
||||
size_t global_work_size[] = {
|
||||
(size_t)(ne01 + rows_per_workgroup - 1)/rows_per_workgroup*nth0,
|
||||
(size_t)ne02*rows_per_workgroup,
|
||||
(size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth0, (size_t)rows_per_workgroup, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -4784,6 +4946,58 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int M = src0->ne[1];
|
||||
const int N = src1->ne[1];
|
||||
const int K = src0->ne[0];
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int), &M));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &N));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &K));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));
|
||||
|
||||
// Tiling parameters. These need to be tuned for optimal performance.
|
||||
// They must match the #defines in the kernel mul_mat_f16_f32.cl.
|
||||
//
|
||||
// OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN.
|
||||
// TPWM / TPWN: Threads per Work-group. This is the work-group size.
|
||||
// OPTM / OPTN: Output elements per Thread. Each thread computes OPTM x OPTN elements.
|
||||
//
|
||||
// The following relationships must hold:
|
||||
// OPWM = TPWM * OPTM
|
||||
// OPWN = TPWN * OPTN
|
||||
//
|
||||
const int OPWM = 64;
|
||||
const int OPWN = 64;
|
||||
const int TPWM = 16;
|
||||
const int TPWN = 8;
|
||||
|
||||
size_t local_work_size[2] = { TPWM, TPWN };
|
||||
size_t global_work_size[2] = {
|
||||
(size_t) ((M + OPWM - 1) / OPWM) * TPWM,
|
||||
(size_t) ((N + OPWN - 1) / OPWN) * TPWN,
|
||||
};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -4797,6 +5011,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
|
||||
src0->ne[1] > 32 && // M > 32
|
||||
src1->ne[1] > 32 && // N > 32
|
||||
src0->ne[0] > 32 && // K > 32
|
||||
src0->ne[2] == 1 && src0->ne[3] == 1 &&
|
||||
src1->ne[2] == 1 && src1->ne[3] == 1 &&
|
||||
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
|
||||
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
|
||||
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
@@ -5587,7 +5813,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(scale));
|
||||
float bias;
|
||||
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
||||
memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
@@ -5602,6 +5830,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));
|
||||
|
||||
int n = ggml_nelements(dst)/4;
|
||||
|
||||
@@ -6385,6 +6614,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||
}
|
||||
func = ggml_cl_get_rows;
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cl_set_rows;
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#if defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#else
|
||||
#define REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
|
||||
#define OPWM 64
|
||||
#define OPWN 64
|
||||
#define CPWK 8
|
||||
#define OPTM 4
|
||||
#define OPTN 8
|
||||
|
||||
#define WG_M (OPWM / OPTM)
|
||||
#define WG_N (OPWN / OPTN)
|
||||
#define VEC_K (CPWK / 4)
|
||||
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
__kernel void mul_mat_f16_f32(
|
||||
const int M, const int N, const int K,
|
||||
__global const void* A_void, ulong A_offset,
|
||||
__global const void* B_void, ulong B_offset,
|
||||
__global void* C_void, ulong C_offset) {
|
||||
|
||||
__global const half* A = (__global const half* )((__global const char*)A_void + A_offset);
|
||||
__global const float* B = (__global const float*)((__global const char*)B_void + B_offset);
|
||||
__global float* C = (__global float*)((__global char*)C_void + C_offset);
|
||||
|
||||
const int lidm = get_local_id(0);
|
||||
const int lidn = get_local_id(1);
|
||||
const int lid = lidn * WG_M + lidm;
|
||||
|
||||
const int offsetM = get_group_id(0) * OPWM;
|
||||
const int offsetN = get_group_id(1) * OPWN;
|
||||
|
||||
__local half4 Alocal[OPWM][VEC_K];
|
||||
__local float4 Blocal[OPWN][VEC_K];
|
||||
|
||||
float sum[OPTM][OPTN];
|
||||
|
||||
for (int wm = 0; wm < OPTM; wm++) {
|
||||
for (int wn = 0; wn < OPTN; wn++) {
|
||||
sum[wm][wn] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
const int numTiles = (K + CPWK - 1) / CPWK;
|
||||
|
||||
const int load_row_a = lid % OPWM;
|
||||
const int load_vec_k_a = lid / OPWM;
|
||||
const int global_row_a = offsetM + load_row_a;
|
||||
|
||||
const int load_row_b = lid % OPWN;
|
||||
const int load_vec_k_b = lid / OPWN;
|
||||
const int global_row_b = offsetN + load_row_b;
|
||||
|
||||
for (int t = 0; t < numTiles; t++) {
|
||||
const int k_start = t * CPWK;
|
||||
const int k_vec_start_a = k_start + load_vec_k_a * 4;
|
||||
const int k_vec_start_b = k_start + load_vec_k_b * 4;
|
||||
|
||||
if (global_row_a < M && k_vec_start_a < K) {
|
||||
if (k_vec_start_a + 3 < K) {
|
||||
Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a);
|
||||
} else {
|
||||
half4 tempA = (half4)(0.0h);
|
||||
if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a];
|
||||
if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1];
|
||||
if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2];
|
||||
Alocal[load_row_a][load_vec_k_a] = tempA;
|
||||
}
|
||||
} else {
|
||||
Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h);
|
||||
}
|
||||
|
||||
if (global_row_b < N && k_vec_start_b < K) {
|
||||
if (k_vec_start_b + 3 < K) {
|
||||
Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b);
|
||||
} else {
|
||||
float4 tempB = (float4)(0.0f);
|
||||
if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b];
|
||||
if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1];
|
||||
if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2];
|
||||
Blocal[load_row_b][load_vec_k_b] = tempB;
|
||||
}
|
||||
} else {
|
||||
Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f);
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
#pragma unroll
|
||||
for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
|
||||
float4 a_fvecs[OPTM];
|
||||
int current_row_a = lidm;
|
||||
for (int wm = 0; wm < OPTM; wm++) {
|
||||
a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);
|
||||
current_row_a += WG_M;
|
||||
}
|
||||
|
||||
float4 b_fvecs[OPTN];
|
||||
int current_row_b = lidn;
|
||||
for (int wn = 0; wn < OPTN; wn++) {
|
||||
b_fvecs[wn] = Blocal[current_row_b][k_vec];
|
||||
current_row_b += WG_N;
|
||||
}
|
||||
|
||||
for (int wm = 0; wm < OPTM; wm++) {
|
||||
for (int wn = 0; wn < OPTN; wn++) {
|
||||
sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
for (int wm = 0; wm < OPTM; wm++) {
|
||||
int globalRow = offsetM + lidm + wm * WG_M;
|
||||
if (globalRow < M) {
|
||||
for (int wn = 0; wn < OPTN; wn++) {
|
||||
int globalCol = offsetN + lidn + wn * WG_N;
|
||||
if (globalCol < N) {
|
||||
C[globalCol * M + globalRow] = sum[wm][wn];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,9 +8,10 @@ kernel void kernel_scale(
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd,
|
||||
float scale
|
||||
float scale,
|
||||
float bias
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
|
||||
dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
kernel void kernel_set_rows_f32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = i03%ne12;
|
||||
int i11 = i02%ne11;
|
||||
|
||||
int i10 = i01;
|
||||
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
|
||||
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
|
||||
dst_row[ind] = (float)src_row[ind];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_set_rows_f16(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne11,
|
||||
int ne12,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
int nblk0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
int i02 = get_group_id(1);
|
||||
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
|
||||
|
||||
if (i01 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int i12 = i03%ne12;
|
||||
int i11 = i02%ne11;
|
||||
|
||||
int i10 = i01;
|
||||
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
|
||||
|
||||
global half * dst_row = (global half *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
|
||||
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
|
||||
dst_row[ind] = src_row[ind];
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,7 @@
|
||||
#include "outprod.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "set_rows.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "wkv.hpp"
|
||||
|
||||
+15
-27
@@ -32,39 +32,28 @@ public:
|
||||
else static_assert(0);
|
||||
}
|
||||
|
||||
// matrix A has m rows, k columns
|
||||
// matrix B has k rows, n columns
|
||||
// nra - number of elements to skip when moving into next row in A
|
||||
// nrb - number of elements to skip when moving into next row in B
|
||||
// nca - number of elements to skip when moving into next column in A
|
||||
// ncb - number of elements to skip when moving into next column in B
|
||||
// stride_a - number of elements to skip when moving to next A matrix
|
||||
// stride_b - number of elements to skip when moving to next B matrix
|
||||
// batches_a - number of A matrices
|
||||
// batches_b - number of B matrices
|
||||
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
|
||||
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
|
||||
const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
|
||||
const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
|
||||
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
|
||||
|
||||
auto stream = ctx.stream_dnnl(q);
|
||||
auto eng = ctx.engine_dnnl(q);
|
||||
|
||||
// { # strides, # rows, # columns }
|
||||
dnnl::memory::dims a_dims = { batches_a, m, k };
|
||||
dnnl::memory::dims b_dims = { batches_b, k, n };
|
||||
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
|
||||
|
||||
// { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
|
||||
dnnl::memory::dims a_strides = { stride_a, nra, nca };
|
||||
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
|
||||
|
||||
dnnl::memory::dims a_dims = {batches_a, m, k };
|
||||
dnnl::memory::dims a_strides = {stra2, stra1, stra0};
|
||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
|
||||
|
||||
dnnl::memory::dims b_dims = {batches_b, k, n };
|
||||
dnnl::memory::dims b_strides = {strb2, strb0, strb1};
|
||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
|
||||
|
||||
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
|
||||
dnnl::memory::dims c_strides = {m*n, 1, m };
|
||||
const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
|
||||
dnnl::primitive_attr primitive_attr;
|
||||
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
#ifdef GGML_SYCL_F16
|
||||
primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
|
||||
#endif
|
||||
@@ -76,24 +65,23 @@ public:
|
||||
|
||||
auto scratchpad_md = matmul_pd.scratchpad_desc();
|
||||
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
|
||||
|
||||
auto matmul_prim = dnnl::matmul(matmul_pd);
|
||||
|
||||
std::unordered_map<int, dnnl::memory> matmul_args;
|
||||
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
||||
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
||||
|
||||
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
||||
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
|
||||
|
||||
matmul_prim.execute(stream, matmul_args);
|
||||
}
|
||||
|
||||
// matrices A and B are column major, both having k rows
|
||||
// matrix A has m column, matrix B has n columns
|
||||
// output: column major matrix C = A transposed * B
|
||||
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
||||
|
||||
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
|
||||
gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@
|
||||
#include "ggml-sycl/element_wise.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
#include "ggml-sycl/set_rows.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
#include "ggml-sycl/getrows.hpp"
|
||||
#include "ggml.h"
|
||||
@@ -1545,7 +1546,7 @@ static void mul_mat_p021_f16_f32(
|
||||
|
||||
static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
|
||||
const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
|
||||
const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
|
||||
const sycl::half *x = (const sycl::half *)vx;
|
||||
@@ -1556,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||
item_ct1.get_local_id(0);
|
||||
const int channel_x = channel / channel_x_divisor;
|
||||
|
||||
const int nrows_y = ncols_x;
|
||||
const int nrows_dst = nrows_x;
|
||||
const int row_dst = row_x;
|
||||
|
||||
@@ -1575,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||
const int row_y = col_x;
|
||||
|
||||
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
|
||||
const int iy = channel*nrows_y + row_y;
|
||||
const int iy = channel * channel_stride_y + row_y;
|
||||
|
||||
const float xi =
|
||||
sycl::vec<sycl::half, 1>(x[ix])
|
||||
@@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
|
||||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||
}
|
||||
|
||||
static void scale_f32(const float * x, float * dst, const float scale, const int k,
|
||||
static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
@@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
|
||||
return;
|
||||
}
|
||||
|
||||
dst[i] = scale * x[i];
|
||||
dst[i] = scale * x[i] + bias;
|
||||
}
|
||||
|
||||
|
||||
@@ -1822,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
|
||||
static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
||||
const void *vx, const float *y, float *dst, const int ncols_x,
|
||||
const int nrows_x, const int row_stride_x, const int nchannels_x,
|
||||
const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
|
||||
const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
|
||||
|
||||
const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
@@ -1834,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
|
||||
row_stride_x, channel_stride_x,
|
||||
row_stride_x, channel_stride_x, channel_stride_y,
|
||||
nchannels_y / nchannels_x, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
|
||||
|
||||
|
||||
|
||||
static void scale_f32_sycl(const float *x, float *dst, const float scale,
|
||||
static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
|
||||
const int k, queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
@@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
|
||||
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
scale_f32(x, dst, scale, k, item_ct1);
|
||||
scale_f32(x, dst, scale, bias, k, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2123,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||
}
|
||||
else
|
||||
@@ -2170,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||
}
|
||||
else
|
||||
@@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
float bias;
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
|
||||
scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
|
||||
/*
|
||||
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
|
||||
error codes. The call was replaced with 0. You need to rewrite this code.
|
||||
@@ -2773,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
|
||||
const int64_t nb02 = src0->nb[2];
|
||||
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t nb11 = src1->nb[1];
|
||||
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
queue_ptr main_stream = ctx.stream();
|
||||
@@ -2783,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
|
||||
|
||||
const int64_t row_stride_x = nb01 / sizeof(sycl::half);
|
||||
const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
|
||||
const int64_t channel_stride_y = nb11 / sizeof(float);
|
||||
|
||||
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
|
||||
ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
@@ -2838,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
float * dst_ddf = static_cast<float *>(dst->data);
|
||||
|
||||
const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
|
||||
const size_t type_size_src0 = ggml_type_size(src0->type);
|
||||
const size_t type_size_src1 = ggml_type_size(src1->type);
|
||||
GGML_ASSERT(nb10 == type_size_src1);
|
||||
|
||||
// SRC1 strides
|
||||
int64_t s11 = nb11 / type_size_src1;
|
||||
@@ -2851,11 +2855,40 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
if (src1->type != GGML_TYPE_F16) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
|
||||
" : converting src1 to fp16");
|
||||
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
|
||||
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
|
||||
|
||||
// iterate tensor dims and find the slowest moving dim and stride
|
||||
int64_t last_dim=0;
|
||||
int64_t last_str=0;
|
||||
int64_t largest_str=0;
|
||||
for(int i = 0; i< 4; i++){
|
||||
// last stride is always the largest
|
||||
if(src1->nb[i] == largest_str){
|
||||
if(src1->ne[last_dim] == 1){
|
||||
last_str = i;
|
||||
last_dim = i;
|
||||
}
|
||||
}
|
||||
if(src1->nb[i] > largest_str){
|
||||
largest_str = src1->nb[i];
|
||||
last_str = i;
|
||||
last_dim = i;
|
||||
}
|
||||
|
||||
}
|
||||
#if GGML_SYCL_DNNL
|
||||
// oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
|
||||
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
|
||||
src1_f16_alloc.alloc(ne_src1);
|
||||
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
|
||||
GGML_ASSERT(to_fp16_sycl != nullptr);
|
||||
to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
|
||||
# else
|
||||
const int64_t ne_src1 = ggml_nelements(src1);
|
||||
src1_f16_alloc.alloc(ne_src1);
|
||||
const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
|
||||
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
|
||||
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
|
||||
#endif
|
||||
|
||||
src1_f16 = src1_f16_alloc.get();
|
||||
s11 = ne10;
|
||||
@@ -2889,38 +2922,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
|
||||
#if GGML_SYCL_DNNL
|
||||
if (!g_ggml_sycl_disable_dnn) {
|
||||
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
|
||||
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
|
||||
int64_t str_a0 = nb00 / type_size_src0;
|
||||
int64_t str_a1 = nb01 / type_size_src0;
|
||||
int64_t str_a2 = nb02 / type_size_src0;
|
||||
|
||||
DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
|
||||
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
|
||||
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
|
||||
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
|
||||
};
|
||||
int64_t str_b0 = nb10 / type_size_src1;
|
||||
int64_t str_b1 = nb11 / type_size_src1;
|
||||
int64_t str_b2 = nb12 / type_size_src1;
|
||||
|
||||
if (r2 == 1 && r3 == 1) {
|
||||
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
|
||||
}
|
||||
else {
|
||||
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
|
||||
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
|
||||
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
|
||||
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
|
||||
auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
|
||||
const sycl::half *src1, float *dst,
|
||||
int64_t a0, int64_t a1, int64_t batcha,
|
||||
int64_t b0, int64_t b1, int64_t batchb,
|
||||
int64_t sa0, int64_t sa1, int64_t sa2,
|
||||
int64_t sb0, int64_t sb1, int64_t sb2,
|
||||
int64_t sd2) {
|
||||
bool supported_broadcast = batchb == batcha ? true
|
||||
: batchb == 1 || batcha == 1 ? true
|
||||
: false;
|
||||
if (supported_broadcast) {
|
||||
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
|
||||
DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
|
||||
} else {
|
||||
// iterate over batches from smaller set of matrices (matrix 0)
|
||||
int64_t batches0 = batcha;
|
||||
int64_t batches1 = batchb;
|
||||
|
||||
if (batches0 > batches1) {
|
||||
int64_t num_mul_mats = batches1;
|
||||
int64_t sub_batch = batches0 / num_mul_mats;
|
||||
// src0 is batched and bigger, shift and multiply with src1
|
||||
for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
|
||||
const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
|
||||
const sycl::half *src1_shifted = src1 + (sb2 * i0);
|
||||
float *dst_shifted = dst + (sd2 * i0 * sub_batch);
|
||||
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
|
||||
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
|
||||
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
|
||||
queue, sub_batch, 1);
|
||||
}
|
||||
} else {
|
||||
int64_t num_mul_mats = batches0;
|
||||
int64_t sub_batch = batches1 / num_mul_mats;
|
||||
// src1 is batched and bigger, shift and multiply with src0
|
||||
for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
|
||||
const sycl::half *src0_shifted = src0 + (sa2 * i1);
|
||||
const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
|
||||
float *dst_shifted = dst + (sd2 * i1 * sub_batch);
|
||||
DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
|
||||
DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
|
||||
src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
|
||||
sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
|
||||
queue, 1, sub_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool cont_batches_a = nb02 * ne02 == nb03;
|
||||
bool cont_batches_b = nb12 * ne12 == nb13;
|
||||
if (cont_batches_a && cont_batches_b) {
|
||||
int64_t batches0 = ne02 * ne03;
|
||||
int64_t batches1 = ne12 * ne13;
|
||||
launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
|
||||
ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
|
||||
str_b2, nb2 / sizeof(float));
|
||||
} else {
|
||||
for (int64_t b_a = 0; b_a < ne03; b_a++) {
|
||||
const sycl::half *src0_f16_shifted
|
||||
= src0_f16 + (nb03 * b_a / type_size_src0);
|
||||
const sycl::half *src1_f16_shifted
|
||||
= src1_f16 + (nb13 * b_a / type_size_src1);
|
||||
float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
|
||||
int64_t batches0 = ne02;
|
||||
int64_t batches1 = ne12;
|
||||
launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
|
||||
ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
|
||||
str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// iterate over batches from smaller set of matrices (matrix 0)
|
||||
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
|
||||
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
|
||||
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
|
||||
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
|
||||
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
else
|
||||
#endif
|
||||
@@ -3260,10 +3344,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
// The kernel from the if path is faster for that specific case, but does not support all mul mats.
|
||||
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
||||
}
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
||||
} else if (use_dequantize_mul_mat_vec) {
|
||||
@@ -3603,6 +3687,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_sycl_get_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_sycl_op_set_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
ggml_sycl_dup(ctx, dst);
|
||||
break;
|
||||
@@ -4297,7 +4384,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
return false;
|
||||
#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
|
||||
return (op->type == GGML_TYPE_F32 || (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64));
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
|
||||
+15
-18
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
|
||||
|
||||
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row * ne0 + i0;
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
||||
return;
|
||||
}
|
||||
|
||||
const int row0 = row % ne1;
|
||||
const int channel0 = row / ne1;
|
||||
|
||||
const int i = row * ne0 + i0;
|
||||
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
||||
|
||||
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row * ne0 + i0;
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
||||
return;
|
||||
}
|
||||
|
||||
const int row0 = row % ne1;
|
||||
const int channel0 = row / ne1;
|
||||
|
||||
const int i = row * ne0 + i0 / 2;
|
||||
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
|
||||
}
|
||||
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const int idst = (row_dst * ne0) + (i0 / 2);
|
||||
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
#include "set_rows.hpp"
|
||||
|
||||
namespace utils {
|
||||
template<typename T>
|
||||
static constexpr bool is_arithmetic_v() {
|
||||
return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TIn, typename TOut>
|
||||
static inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>
|
||||
convert (const char* src, char* dst) {
|
||||
auto src_val = *reinterpret_cast<const TIn*>(src);
|
||||
auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];
|
||||
*reinterpret_cast<TOut*>(dst) = dst_val;
|
||||
}
|
||||
|
||||
template<typename TIn, typename TOut>
|
||||
static void k_set_rows(
|
||||
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne11, const int64_t ne12,
|
||||
const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
const size_t src_type_size, const size_t dst_type_size,
|
||||
const int64_t total_elements,
|
||||
const sycl::nd_item<1> & item_ct1) {
|
||||
|
||||
const int64_t i = item_ct1.get_global_linear_id();
|
||||
if (i >= total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i03 = i / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
|
||||
|
||||
const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
|
||||
const char * src_elem = src0_row + i00 * src_type_size;
|
||||
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
|
||||
char * dst_elem = dst_row_ptr + i00 * dst_type_size;
|
||||
|
||||
convert<TIn, TOut>(src_elem, dst_elem);
|
||||
}
|
||||
|
||||
template<typename TIn, typename TOut>
|
||||
static void set_rows_sycl(
|
||||
const char * src0_d, const int64_t * src1_d, char * dst_d,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
const size_t src_type_size, const size_t dst_type_size,
|
||||
queue_ptr stream) {
|
||||
|
||||
const int64_t total_elements = ne00 * ne01 * ne02 * ne03;
|
||||
|
||||
constexpr int block_size = 64;
|
||||
const int64_t grid_size = ceil_div(total_elements, block_size);
|
||||
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<1>(grid_size * block_size, block_size),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
k_set_rows<TIn, TOut>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02,
|
||||
ne11, ne12,
|
||||
nb01, nb02, nb03,
|
||||
nb10, nb11, nb12,
|
||||
nb1, nb2, nb3,
|
||||
src_type_size, dst_type_size,
|
||||
total_elements,
|
||||
item_ct1
|
||||
);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int64_t * src1_dd = static_cast<const int64_t *>(src1->data);
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32:
|
||||
set_rows_sycl<float, float>(
|
||||
(const char *)src0->data, src1_dd, (char *)dst->data,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne11, ne12,
|
||||
nb01, nb02, nb03,
|
||||
nb10, nb11, nb12,
|
||||
nb1, nb2, nb3,
|
||||
sizeof(float), sizeof(float),
|
||||
stream
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
set_rows_sycl<float, sycl::half>(
|
||||
(const char *)src0->data, src1_dd, (char *)dst->data,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne11, ne12,
|
||||
nb01, nb02, nb03,
|
||||
nb10, nb11, nb12,
|
||||
nb1, nb2, nb3,
|
||||
sizeof(float), sizeof(sycl::half),
|
||||
stream
|
||||
);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported tensor type!");
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
#ifndef GGML_SYCL_SET_ROWS_HPP
|
||||
#define GGML_SYCL_SET_ROWS_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_SET_ROWS_HPP
|
||||
@@ -425,18 +425,20 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_div_norepeat[2][2][2];
|
||||
|
||||
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
||||
vk_pipeline pipeline_upscale_f32;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
||||
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
|
||||
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
|
||||
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_norm_f32;
|
||||
vk_pipeline pipeline_group_norm_f32;
|
||||
vk_pipeline pipeline_rms_norm_f32;
|
||||
@@ -501,6 +503,8 @@ struct vk_device_struct {
|
||||
|
||||
ggml_backend_buffer_type buffer_type;
|
||||
|
||||
bool disable_fusion;
|
||||
|
||||
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
||||
std::unique_ptr<vk_memory_logger> memory_logger;
|
||||
#endif
|
||||
@@ -691,6 +695,37 @@ struct vk_op_unary_push_constants {
|
||||
};
|
||||
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
|
||||
|
||||
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
|
||||
GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
|
||||
ne = ne != 0 ? ne : ggml_nelements(dst);
|
||||
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
|
||||
|
||||
vk_op_unary_push_constants p{};
|
||||
p.ne = (uint32_t)ne;
|
||||
|
||||
size_t src0_tsize = ggml_type_size(src0->type);
|
||||
p.ne00 = (uint32_t)src0->ne[0];
|
||||
p.ne01 = (uint32_t)src0->ne[1];
|
||||
p.ne02 = (uint32_t)src0->ne[2];
|
||||
p.ne03 = (uint32_t)src0->ne[3];
|
||||
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
|
||||
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
|
||||
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
|
||||
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
|
||||
|
||||
size_t dst_tsize = ggml_type_size(dst->type);
|
||||
p.ne10 = (uint32_t)dst->ne[0];
|
||||
p.ne11 = (uint32_t)dst->ne[1];
|
||||
p.ne12 = (uint32_t)dst->ne[2];
|
||||
p.ne13 = (uint32_t)dst->ne[3];
|
||||
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
|
||||
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
|
||||
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
|
||||
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
|
||||
|
||||
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
|
||||
}
|
||||
|
||||
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
||||
// Precompute mp (m' in the paper) and L such that division
|
||||
// can be computed using a multiply (high 32b of 64b result)
|
||||
@@ -860,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
|
||||
|
||||
struct vk_op_upscale_push_constants {
|
||||
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
||||
uint32_t ne00; uint32_t ne01;
|
||||
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
@@ -1091,8 +1127,8 @@ static size_t vk_skip_checks;
|
||||
static size_t vk_output_tensor;
|
||||
|
||||
static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
|
||||
static void ggml_vk_check_results_0(ggml_tensor * tensor);
|
||||
static void ggml_vk_check_results_1(ggml_tensor * tensor);
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
||||
#endif
|
||||
|
||||
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
@@ -1733,7 +1769,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
|
||||
// number of rows/cols for flash attention shader
|
||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
||||
|
||||
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
|
||||
if (hsv >= 512) {
|
||||
return 2;
|
||||
} else {
|
||||
return 8;
|
||||
}
|
||||
}
|
||||
|
||||
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
||||
// 128 threads split into four subgroups, each subgroup does 1/4
|
||||
@@ -1758,7 +1801,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, 64};
|
||||
} else {
|
||||
return {scalar_flash_attention_num_large_rows, 32};
|
||||
return {get_fa_scalar_num_large_rows(hsv), 32};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1777,7 +1820,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
||||
|
||||
// small cols to reduce register count
|
||||
if (ggml_is_quantized(type) || hsk >= 256) {
|
||||
return {64, 32};
|
||||
if (hsk >= 512) {
|
||||
return {32, 32};
|
||||
} else {
|
||||
return {64, 32};
|
||||
}
|
||||
}
|
||||
return {64, 64};
|
||||
}
|
||||
@@ -1819,7 +1866,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
const uint32_t warps = warptile[0] / warptile[10];
|
||||
|
||||
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
||||
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
||||
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
|
||||
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
||||
|
||||
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
||||
@@ -1944,10 +1991,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
s_mmq_wg_denoms_k = { 32, 32, 1 };
|
||||
|
||||
// spec constants and tile sizes for quant matmul_id
|
||||
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
||||
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
|
||||
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
||||
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
||||
l_mmqid_wg_denoms = { 128, 64, 1 };
|
||||
l_mmqid_wg_denoms = { 128, 128, 1 };
|
||||
m_mmqid_wg_denoms = { 128, 64, 1 };
|
||||
s_mmqid_wg_denoms = { 128, 64, 1 };
|
||||
|
||||
@@ -2704,7 +2751,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
||||
@@ -2736,19 +2783,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
||||
}
|
||||
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||
@@ -2766,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
return s;
|
||||
};
|
||||
|
||||
bool rte = device->float_controls_rte_fp16;
|
||||
#define CREATE_BINARY(name, namemod, spec) \
|
||||
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
|
||||
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
|
||||
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
|
||||
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
|
||||
|
||||
CREATE_BINARY(add, , {0})
|
||||
@@ -2788,7 +2858,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@@ -2800,6 +2872,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@@ -2817,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
#undef CREATE_UNARY
|
||||
|
||||
#define CREATE_GLU(name) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||
}
|
||||
|
||||
CREATE_GLU(geglu)
|
||||
CREATE_GLU(reglu)
|
||||
@@ -3507,6 +3586,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
device->idx = idx;
|
||||
|
||||
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
|
||||
|
||||
return device;
|
||||
}
|
||||
|
||||
@@ -4841,7 +4922,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
|
||||
return
|
||||
tensor->nb[0] == ggml_type_size(tensor->type) &&
|
||||
tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
|
||||
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
||||
(tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
|
||||
@@ -6044,7 +6125,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
||||
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
@@ -6169,7 +6250,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
case FA_SCALAR:
|
||||
case FA_COOPMAT1:
|
||||
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||
max_gqa = scalar_flash_attention_num_large_rows;
|
||||
max_gqa = get_fa_scalar_num_large_rows(HSV);
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
||||
@@ -6248,13 +6329,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
|
||||
// Try to use split_k when KV is large enough to be worth the overhead
|
||||
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
|
||||
if (workgroups_x == 1 && shader_core_count > 0) {
|
||||
// Try to run two workgroups per SM.
|
||||
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
workgroups_x = split_k;
|
||||
}
|
||||
@@ -6388,7 +6469,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
},
|
||||
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
|
||||
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
||||
} else {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
@@ -6464,8 +6545,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_UPSCALE:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
|
||||
return ctx->device->pipeline_upscale_f32;
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
int mode = ggml_get_op_params_i32(dst, 0);
|
||||
switch (mode) {
|
||||
case GGML_SCALE_MODE_NEAREST:
|
||||
return ctx->device->pipeline_upscale_nearest_f32;
|
||||
case GGML_SCALE_MODE_BILINEAR:
|
||||
return ctx->device->pipeline_upscale_bilinear_f32;
|
||||
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
|
||||
return ctx->device->pipeline_upscale_bilinear_ac_f32;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SCALE:
|
||||
@@ -6498,6 +6587,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_pad_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ROLL:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_roll_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_REPEAT:
|
||||
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
|
||||
return ctx->device->pipeline_repeat_f32;
|
||||
@@ -6512,6 +6606,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DUP:
|
||||
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
|
||||
case GGML_OP_SET_ROWS:
|
||||
return ctx->device->pipeline_set_rows[dst->type];
|
||||
case GGML_OP_SILU_BACK:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_silu_back_f32;
|
||||
@@ -6750,6 +6846,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_SET_ROWS:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -7044,6 +7141,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_CPY:
|
||||
@@ -7063,6 +7161,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
ne *= ggml_type_size(src0->type) / 2;
|
||||
}
|
||||
}
|
||||
// copy_to_quant has block size of 32, and each thread does QUANT_K elements.
|
||||
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
|
||||
// So divide by block size here before splitting into 512x512 groups.
|
||||
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
||||
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
|
||||
}
|
||||
if (ne > 262144) {
|
||||
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
||||
} else if (ne > 512) {
|
||||
@@ -7071,6 +7175,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
elements = { ne, 1, 1 };
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
uint32_t ne = ggml_nelements(src0);
|
||||
if (ggml_is_quantized(dst->type)) {
|
||||
// quants run 32 threads each doing QUANT_K elements
|
||||
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
|
||||
} else {
|
||||
// scalar types do one element per thread, running 512 threads
|
||||
ne = CEIL_DIV(ne, 512);
|
||||
}
|
||||
if (ne > 262144) {
|
||||
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
||||
} else if (ne > 512) {
|
||||
elements = { 512, CEIL_DIV(ne, 512), 1 };
|
||||
} else {
|
||||
elements = { ne, 1, 1 };
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
|
||||
break;
|
||||
@@ -7480,14 +7603,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
|
||||
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
const float sf0 = (float)dst->ne[0] / src0->ne[0];
|
||||
const float sf1 = (float)dst->ne[1] / src0->ne[1];
|
||||
const float sf2 = (float)dst->ne[2] / src0->ne[2];
|
||||
const float sf3 = (float)dst->ne[3] / src0->ne[3];
|
||||
float sf0 = (float)dst->ne[0] / src0->ne[0];
|
||||
float sf1 = (float)dst->ne[1] / src0->ne[1];
|
||||
float sf2 = (float)dst->ne[2] / src0->ne[2];
|
||||
float sf3 = (float)dst->ne[3] / src0->ne[3];
|
||||
|
||||
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
||||
(uint32_t)ggml_nelements(dst), 0, 0,
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
|
||||
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
||||
sf0, sf1, sf2, sf3,
|
||||
@@ -7495,123 +7625,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||
}
|
||||
|
||||
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = ggml_get_op_params_f32(dst, 0);
|
||||
p.param2 = ggml_get_op_params_f32(dst, 1);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
op_params[0], 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = ggml_get_op_params_f32(dst, 0);
|
||||
p.param2 = ggml_get_op_params_f32(dst, 1);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
op_params[0], op_params[1],
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
|
||||
(uint32_t)ggml_nelements(dst),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
||||
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
||||
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
|
||||
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
|
||||
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
|
||||
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
|
||||
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
memcpy(&p.param1, &s01_packed, sizeof(float));
|
||||
memcpy(&p.param2, &s23_packed, sizeof(float));
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
|
||||
(uint32_t)ggml_nelements(dst),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
|
||||
(uint32_t)ggml_nelements(dst),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
}, dryrun);
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
uint32_t ne = (uint32_t)ggml_nelements(src0);
|
||||
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
||||
// Convert from number of logical elements to 2- or 4-byte units.
|
||||
@@ -7623,13 +7694,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
}
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
||||
ne,
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0.0f, 0.0f, 0,
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
@@ -7654,8 +7734,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
@@ -8885,7 +8964,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
||||
@@ -8953,7 +9032,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SILU_BACK:
|
||||
@@ -9020,6 +9101,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SILU_BACK:
|
||||
@@ -9122,12 +9204,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_PAD:
|
||||
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_ROLL:
|
||||
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DUP:
|
||||
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_SILU_BACK:
|
||||
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||
@@ -9146,9 +9236,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
// fused rms_norm + mul
|
||||
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
||||
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
|
||||
} else {
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
|
||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
@@ -9308,7 +9398,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
|
||||
ctx->compute_ctx.reset();
|
||||
|
||||
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
|
||||
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
|
||||
if (!ok) {
|
||||
if (node->op == GGML_OP_UNARY) {
|
||||
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
||||
@@ -9323,7 +9413,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
||||
GGML_UNUSED(cgraph);
|
||||
ggml_backend_buffer * buf = nullptr;
|
||||
|
||||
switch (tensor->op) {
|
||||
@@ -9341,7 +9432,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SILU_BACK:
|
||||
@@ -9433,7 +9526,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
// Only run if ctx hasn't been submitted yet
|
||||
if (!subctx->seqs.empty()) {
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
ggml_vk_check_results_0(tensor);
|
||||
ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
|
||||
use_fence = true;
|
||||
#endif
|
||||
|
||||
@@ -9453,7 +9546,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
ggml_vk_wait_for_fence(ctx);
|
||||
}
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
ggml_vk_check_results_1(tensor);
|
||||
ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -9900,6 +9993,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
|
||||
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
||||
}
|
||||
|
||||
static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
|
||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
||||
// additional constraints specific to this fusion
|
||||
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
|
||||
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
||||
|
||||
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
||||
// rms_norm only supports f32
|
||||
if (mul->src[0]->type != GGML_TYPE_F32 ||
|
||||
mul->src[1]->type != GGML_TYPE_F32 ||
|
||||
mul->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
// if rms_norm is the B operand, then we don't handle broadcast
|
||||
if (rms_norm == mul->src[1] &&
|
||||
mul->src[0]->ne[1] != rms_norm->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
// rms_norm shader assumes contiguous rows
|
||||
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
@@ -9913,7 +10037,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
|
||||
uint64_t total_mat_mul_bytes = 0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
}
|
||||
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
||||
@@ -9983,7 +10107,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
||||
}
|
||||
|
||||
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
}
|
||||
|
||||
@@ -10232,10 +10356,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
||||
return false;
|
||||
}
|
||||
// Check against size of shared memory variable
|
||||
if (op->src[2]->ne[0] > 4096) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_F32:
|
||||
@@ -10376,9 +10496,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
// TODO: add support
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
|
||||
return false;
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_CPY:
|
||||
@@ -10464,13 +10595,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
return true;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
case GGML_OP_ARGSORT:
|
||||
@@ -10760,11 +10890,21 @@ void * comp_result;
|
||||
size_t comp_size;
|
||||
size_t comp_nb[GGML_MAX_DIMS];
|
||||
size_t check_counter = 0;
|
||||
static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
||||
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
||||
if (tensor->op == GGML_OP_TRANSPOSE) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool fused_rms_norm_mul = false;
|
||||
int rms_norm_idx = -1;
|
||||
if (ctx->num_additional_fused_ops == 1 &&
|
||||
tensor->op == GGML_OP_RMS_NORM &&
|
||||
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
||||
fused_rms_norm_mul = true;
|
||||
tensor = cgraph->nodes[tensor_idx + 1];
|
||||
}
|
||||
|
||||
check_counter++;
|
||||
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
||||
return;
|
||||
@@ -10792,6 +10932,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
ggml_tensor * srci = tensor->src[i];
|
||||
if (fused_rms_norm_mul) {
|
||||
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
|
||||
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
|
||||
switch (i) {
|
||||
case 0: srci = rms_norm->src[0]; break;
|
||||
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
|
||||
default: continue;
|
||||
}
|
||||
}
|
||||
if (srci == nullptr) {
|
||||
continue;
|
||||
}
|
||||
@@ -10849,7 +10998,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_SUB) {
|
||||
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_MUL) {
|
||||
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
if (fused_rms_norm_mul) {
|
||||
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
|
||||
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
|
||||
} else {
|
||||
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
}
|
||||
} else if (tensor->op == GGML_OP_DIV) {
|
||||
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_CONCAT) {
|
||||
@@ -10969,6 +11123,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else {
|
||||
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
}
|
||||
} else if (tensor->op == GGML_OP_SET_ROWS) {
|
||||
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_CONT) {
|
||||
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
} else if (tensor->op == GGML_OP_RESHAPE) {
|
||||
@@ -11040,10 +11196,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
|
||||
ggml_build_forward_expand(cgraph, tensor_clone);
|
||||
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
|
||||
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
|
||||
|
||||
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
|
||||
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
|
||||
|
||||
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
|
||||
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
|
||||
@@ -11066,10 +11222,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
|
||||
}
|
||||
|
||||
static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
||||
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
||||
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
||||
if (tensor->op == GGML_OP_TRANSPOSE) {
|
||||
return;
|
||||
}
|
||||
bool fused_rms_norm_mul = false;
|
||||
if (ctx->num_additional_fused_ops == 1 &&
|
||||
tensor->op == GGML_OP_RMS_NORM &&
|
||||
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
||||
fused_rms_norm_mul = true;
|
||||
tensor = cgraph->nodes[tensor_idx + 1];
|
||||
}
|
||||
|
||||
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
#version 450
|
||||
|
||||
#if RTE16
|
||||
#extension GL_EXT_spirv_intrinsics : enable
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
#endif // RTE16
|
||||
|
||||
#include "rte.comp"
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
// 16 invocations needed for init_iq4nl_shmem
|
||||
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
|
||||
#if defined(SET_ROWS) && QUANT_K == 1
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint BLOCK_SIZE = 512;
|
||||
#else
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||
const uint BLOCK_SIZE = 32;
|
||||
#endif
|
||||
|
||||
layout (binding = 0) readonly buffer S {float data_s[];};
|
||||
|
||||
#if defined(SET_ROWS)
|
||||
#include "generic_binary_head.comp"
|
||||
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
|
||||
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
|
||||
#else
|
||||
#include "generic_unary_head.comp"
|
||||
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
void quantize(uint dst_idx, uint src_idx)
|
||||
@@ -221,15 +225,56 @@ void quantize(uint dst_idx, uint src_idx)
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
void quantize(uint dst_idx, uint src_idx)
|
||||
{
|
||||
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
void quantize(uint dst_idx, uint src_idx)
|
||||
{
|
||||
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(SET_ROWS)
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
|
||||
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
uint i12 = fastmod(i03, p.ne12);
|
||||
uint i11 = fastmod(i02, p.ne11);
|
||||
uint i10 = i01;
|
||||
|
||||
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
|
||||
|
||||
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
|
||||
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
|
||||
|
||||
quantize(dst_idx, src0_idx);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
@@ -240,3 +285,5 @@ void main() {
|
||||
|
||||
quantize(dst_idx, src_idx);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
void main() {
|
||||
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
||||
if (i >= p.M * p.K / QUANT_K) {
|
||||
if (i >= p.nel / QUANT_K) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
void main() {
|
||||
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
|
||||
if (i >= p.M * p.K / QUANT_K) {
|
||||
if (i >= p.nel / QUANT_K) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
void main() {
|
||||
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||
const uint ib = gl_WorkGroupID.x * 256 + wgy;
|
||||
if (ib >= p.M * p.K / QUANT_K) {
|
||||
if (ib >= p.nel / QUANT_K) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
void main() {
|
||||
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||
const uint ib = gl_WorkGroupID.x * 256 + wgy;
|
||||
if (ib >= p.M * p.K / QUANT_K) {
|
||||
if (ib >= p.nel / QUANT_K) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
void main() {
|
||||
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
||||
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
||||
if (i >= p.M * p.K / QUANT_K) {
|
||||
if (i >= p.nel / QUANT_K) {
|
||||
return;
|
||||
}
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
@@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
|
||||
uint k_num;
|
||||
} p;
|
||||
|
||||
shared float tmpsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
@@ -32,23 +34,51 @@ void main() {
|
||||
|
||||
// Compute the max m value for the row
|
||||
float m_max = -1.0/0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
m_max = max(m_max, m);
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = m_max;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
m_max = max(m_max, tmpsh[tid + s]);
|
||||
tmpsh[tid] = m_max;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
m_max = tmpsh[0];
|
||||
|
||||
barrier();
|
||||
|
||||
// Compute L based on m_max
|
||||
float L = 0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float l = data_a[l_offset + k * lm_stride];
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float l = data_a[l_offset + (k + tid) * lm_stride];
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
L += exp(m - m_max) * l;
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
L += tmpsh[tid + s];
|
||||
tmpsh[tid] = L;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
L = tmpsh[0];
|
||||
|
||||
L = 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
||||
// Scale and sum the O contributions based on m_max and store the result to memory
|
||||
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
||||
if (d < D) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "rte.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ne;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#include "rte.comp"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_spirv_intrinsics: enable
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#if RTE16
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
#endif
|
||||
#include "rte.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
uint _ne1;
|
||||
#ifdef COOPMAT
|
||||
shared uint _ne1_sh;
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
@@ -172,7 +177,47 @@ void main() {
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint _ne1 = 0;
|
||||
#ifdef COOPMAT
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
@@ -183,6 +228,7 @@ void main() {
|
||||
}
|
||||
|
||||
barrier();
|
||||
#endif
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
@@ -500,10 +546,9 @@ void main() {
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 32;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
@@ -512,22 +557,16 @@ void main() {
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4;
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32;
|
||||
const uint ib16 = ib8 / 2;
|
||||
const int i8 = 2 * int(idx % 4);
|
||||
|
||||
const uint16_t[4] scales = data_a[ib].scales;
|
||||
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||
@@ -538,21 +577,17 @@ void main() {
|
||||
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
const ivec2 gvec = ivec2(
|
||||
bitfieldExtract(grid, 2 * (i8), 2),
|
||||
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
||||
);
|
||||
const vec2 v = dl * (vec2(gvec) + delta);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4;
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 4;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
||||
@@ -562,63 +597,81 @@ void main() {
|
||||
data_a[ib].qs[8*ib32 + 6],
|
||||
data_a[ib].qs[8*ib32 + 7]
|
||||
));
|
||||
const float db = d * 0.25 * (0.5 + (signs >> 28));
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||
const uvec2 grid = iq2xxs_grid[qs];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4; // 0..3
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
const uint ib8 = idx % 4; // 0..3
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
||||
const uint sign7 = qs >> 9;
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||
const uvec2 grid = iq2xs_grid[qs & 511];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
|
||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
|
||||
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
@@ -631,33 +684,36 @@ void main() {
|
||||
));
|
||||
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
|
||||
const uint grid = iq3xxs_grid[qs];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
const uint iqh = iqs / 8;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[iqs];
|
||||
const uint qh = data_a[ib].qh[iqh];
|
||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
|
||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
|
||||
const uint scale = data_a[ib].scales[iqs / 16];
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
@@ -162,17 +162,32 @@ void main() {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
}
|
||||
@@ -414,17 +429,31 @@ void main() {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
#endif
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
#endif
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
uint wrap_idx(int i, uint ne) {
|
||||
if (i < 0) {
|
||||
return i + ne;
|
||||
} else if (i >= ne) {
|
||||
return i - ne;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
|
||||
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
|
||||
const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
|
||||
const uint i2_offset = i2*p.ne11*p.ne10;
|
||||
const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
|
||||
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
|
||||
|
||||
const uint p1 = floatBitsToUint(p.param1);
|
||||
const uint p2 = floatBitsToUint(p.param2);
|
||||
const int s0 = int(p1 >> 16) - 0x8000;
|
||||
const int s1 = int(p1 & 0xFFFF) - 0x8000;
|
||||
const int s2 = int(p2 >> 16) - 0x8000;
|
||||
const int s3 = int(p2 & 0xFFFF) - 0x8000;
|
||||
|
||||
const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
|
||||
const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
|
||||
const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
|
||||
const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
|
||||
|
||||
const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
|
||||
const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
|
||||
|
||||
data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
|
||||
}
|
||||
@@ -1,11 +1,8 @@
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_spirv_intrinsics: enable
|
||||
|
||||
#if RTE16
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
#endif
|
||||
#include "rte.comp"
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
|
||||
|
||||
|
||||
@@ -14,21 +14,19 @@ void main() {
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
||||
@@ -13,21 +13,19 @@ void main() {
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -13,21 +13,19 @@ void main() {
|
||||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + 0] = data_a[ix + 0];
|
||||
data_d[idst + 1] = data_a[ix + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
|
||||
#if RTE16
|
||||
#extension GL_EXT_spirv_intrinsics : enable
|
||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||
#endif // RTE16
|
||||
@@ -18,7 +18,7 @@ void main() {
|
||||
continue;
|
||||
}
|
||||
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
|
||||
idx += num_threads;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ne; uint a_offset; uint d_offset;
|
||||
uint ne00; uint ne01;
|
||||
uint nb00; uint nb01; uint nb02; uint nb03;
|
||||
uint ne10; uint ne11; uint ne12; uint ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
@@ -15,6 +16,61 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
|
||||
#define NEAREST 0
|
||||
#define BILINEAR 1
|
||||
#define ALIGN_CORNERS (1 << 8)
|
||||
|
||||
layout (constant_id = 0) const uint scale_mode = 0;
|
||||
|
||||
float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
|
||||
const uint i00 = uint(i10 / p.sf0);
|
||||
const uint i01 = uint(i11 / p.sf1);
|
||||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
|
||||
return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
|
||||
}
|
||||
|
||||
float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
|
||||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
|
||||
|
||||
const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
|
||||
const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
|
||||
const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
|
||||
const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
|
||||
|
||||
return
|
||||
v00 * (1.0-d.x) * (1.0-d.y) +
|
||||
v01 * d.x * (1.0-d.y) +
|
||||
v10 * (1.0-d.x) * d.y +
|
||||
v11 * d.x * d.y;
|
||||
}
|
||||
|
||||
float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
|
||||
const ivec2 ne0 = ivec2(p.ne00, p.ne01);
|
||||
|
||||
const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = max(ivec2(c0f), 0);
|
||||
const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
|
||||
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
|
||||
const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
|
||||
const vec2 c0f = floor(c);
|
||||
const vec2 d = c - c0f;
|
||||
const ivec2 c0 = ivec2(c0f);
|
||||
const ivec2 c1 = c0 + 1;
|
||||
|
||||
return fetch_bilinear(c0, c1, d, i12, i13);
|
||||
}
|
||||
|
||||
void main() {
|
||||
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
@@ -27,10 +83,18 @@ void main() {
|
||||
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
|
||||
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
|
||||
|
||||
const uint i00 = uint(i10 / p.sf0);
|
||||
const uint i01 = uint(i11 / p.sf1);
|
||||
const uint i02 = uint(i12 / p.sf2);
|
||||
const uint i03 = uint(i13 / p.sf3);
|
||||
float result;
|
||||
switch (scale_mode) {
|
||||
case NEAREST:
|
||||
result = fetch_nearest(i10, i11, i12, i13);
|
||||
break;
|
||||
case BILINEAR:
|
||||
result = interpolate_bilinear(i10, i11, i12, i13);
|
||||
break;
|
||||
case BILINEAR | ALIGN_CORNERS:
|
||||
result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
|
||||
break;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
|
||||
data_d[p.d_offset + idx] = D_TYPE(result);
|
||||
}
|
||||
|
||||
@@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
std::string load_vec_quant = "2";
|
||||
if ((tname == "q4_0") || (tname == "q4_1"))
|
||||
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
|
||||
load_vec_quant = "8";
|
||||
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
|
||||
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
|
||||
load_vec_quant = "4";
|
||||
|
||||
if (tname == "bf16") {
|
||||
@@ -518,6 +518,11 @@ void process_shaders() {
|
||||
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
}
|
||||
|
||||
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||
string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
}
|
||||
|
||||
auto get_type_str = [](bool f16) {
|
||||
return f16 ? "float16_t" : "float";
|
||||
};
|
||||
@@ -532,8 +537,10 @@ void process_shaders() {
|
||||
for (auto src0_f16 : {false, true}) {
|
||||
for (auto src1_f16 : {false, true}) {
|
||||
for (auto dst_f16 : {false, true}) {
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
|
||||
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
|
||||
for (auto rte : {false, true}) {
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
||||
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -587,16 +594,19 @@ void process_shaders() {
|
||||
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
for (auto rte : {false, true}) {
|
||||
std::string suffix = rte ? "_rte" : "";
|
||||
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
|
||||
string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
}
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -648,6 +658,8 @@ void process_shaders() {
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
|
||||
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
@@ -702,11 +714,59 @@ void write_output_files() {
|
||||
std::remove(path.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (const char *op : {"add", "sub", "mul", "div"}) {
|
||||
fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
|
||||
fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
|
||||
fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
|
||||
fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
|
||||
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
|
||||
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
|
||||
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
|
||||
std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
|
||||
for (uint32_t t0 = 0; t0 < 2; ++t0) {
|
||||
if (t0 == 0) {
|
||||
data += "{";
|
||||
len += "{";
|
||||
}
|
||||
for (uint32_t t1 = 0; t1 < 2; ++t1) {
|
||||
if (t1 == 0) {
|
||||
data += "{";
|
||||
len += "{";
|
||||
}
|
||||
for (uint32_t t2 = 0; t2 < 2; ++t2) {
|
||||
if (t2 == 0) {
|
||||
data += "{";
|
||||
len += "{";
|
||||
}
|
||||
for (uint32_t rte = 0; rte < 2; ++rte) {
|
||||
if (rte == 0) {
|
||||
data += "{";
|
||||
len += "{";
|
||||
}
|
||||
data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
|
||||
len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
|
||||
data += "_data,";
|
||||
len += "_len,";
|
||||
if (rte == 1) {
|
||||
data += "}, ";
|
||||
len += "}, ";
|
||||
}
|
||||
}
|
||||
if (t2 == 1) {
|
||||
data += "}, ";
|
||||
len += "}, ";
|
||||
}
|
||||
}
|
||||
if (t1 == 1) {
|
||||
data += "}, ";
|
||||
len += "}, ";
|
||||
}
|
||||
}
|
||||
if (t0 == 1) {
|
||||
data += "};\n";
|
||||
len += "};\n";
|
||||
}
|
||||
}
|
||||
fprintf(src, data.c_str());
|
||||
fprintf(src, len.c_str());
|
||||
}
|
||||
fclose(hdr);
|
||||
fclose(src);
|
||||
|
||||
+23
-5
@@ -3069,12 +3069,14 @@ static struct ggml_tensor * ggml_scale_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s,
|
||||
float b,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_is_padded_1d(a));
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, &s, sizeof(s));
|
||||
float params[2] = { s, b };
|
||||
ggml_set_op_params(result, ¶ms, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_SCALE;
|
||||
result->src[0] = a;
|
||||
@@ -3086,14 +3088,30 @@ struct ggml_tensor * ggml_scale(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s) {
|
||||
return ggml_scale_impl(ctx, a, s, false);
|
||||
return ggml_scale_impl(ctx, a, s, 0.0, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_scale_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s) {
|
||||
return ggml_scale_impl(ctx, a, s, true);
|
||||
return ggml_scale_impl(ctx, a, s, 0.0, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_scale_bias(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s,
|
||||
float b) {
|
||||
return ggml_scale_impl(ctx, a, s, b, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_scale_bias_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float s,
|
||||
float b) {
|
||||
return ggml_scale_impl(ctx, a, s, b, true);
|
||||
}
|
||||
|
||||
// ggml_set
|
||||
@@ -5777,7 +5795,7 @@ static void ggml_compute_backward(
|
||||
} break;
|
||||
case GGML_OP_MEAN: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
|
||||
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_REPEAT: {
|
||||
@@ -5854,7 +5872,7 @@ static void ggml_compute_backward(
|
||||
if (src0_needs_grads) {
|
||||
float s;
|
||||
memcpy(&s, tensor->op_params, sizeof(float));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SET: {
|
||||
|
||||
+8
-1
@@ -631,7 +631,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
|
||||
size_t padded_size = GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
|
||||
if (SIZE_MAX - ctx->size < padded_size) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' size overflow, cannot accumulate size %zu + %zu\n",
|
||||
__func__, ti.t.name, ctx->size, padded_size);
|
||||
gguf_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ctx->size += padded_size;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -187,6 +187,9 @@ class Keys:
|
||||
class Classifier:
|
||||
OUTPUT_LABELS = "{arch}.classifier.output_labels"
|
||||
|
||||
class ShortConv:
|
||||
L_CACHE = "{arch}.shortconv.l_cache"
|
||||
|
||||
class Tokenizer:
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
PRE = "tokenizer.ggml.pre"
|
||||
@@ -288,6 +291,7 @@ class MODEL_ARCH(IntEnum):
|
||||
LLAMA4 = auto()
|
||||
DECI = auto()
|
||||
FALCON = auto()
|
||||
FALCON_H1 = auto()
|
||||
BAICHUAN = auto()
|
||||
GROK = auto()
|
||||
GPT2 = auto()
|
||||
@@ -313,6 +317,7 @@ class MODEL_ARCH(IntEnum):
|
||||
PHI3 = auto()
|
||||
PHIMOE = auto()
|
||||
PLAMO = auto()
|
||||
PLAMO2 = auto()
|
||||
CODESHELL = auto()
|
||||
ORION = auto()
|
||||
INTERNLM2 = auto()
|
||||
@@ -329,6 +334,7 @@ class MODEL_ARCH(IntEnum):
|
||||
ARWKV7 = auto()
|
||||
MAMBA = auto()
|
||||
MAMBA2 = auto()
|
||||
JAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
COHERE2 = auto()
|
||||
@@ -350,6 +356,7 @@ class MODEL_ARCH(IntEnum):
|
||||
EXAONE = auto()
|
||||
GRANITE = auto()
|
||||
GRANITE_MOE = auto()
|
||||
GRANITE_HYBRID = auto()
|
||||
CHAMELEON = auto()
|
||||
WAVTOKENIZER_DEC = auto()
|
||||
PLM = auto()
|
||||
@@ -357,6 +364,10 @@ class MODEL_ARCH(IntEnum):
|
||||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
ERNIE4_5 = auto()
|
||||
HUNYUAN_MOE = auto()
|
||||
SMOLLM3 = auto()
|
||||
LFM2 = auto()
|
||||
DREAM = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -429,7 +440,10 @@ class MODEL_TENSOR(IntEnum):
|
||||
SSM_CONV1D = auto()
|
||||
SSM_X = auto()
|
||||
SSM_DT = auto()
|
||||
SSM_DT_NORM = auto()
|
||||
SSM_A = auto()
|
||||
SSM_B_NORM = auto()
|
||||
SSM_C_NORM = auto()
|
||||
SSM_D = auto()
|
||||
SSM_NORM = auto()
|
||||
SSM_OUT = auto()
|
||||
@@ -525,6 +539,9 @@ class MODEL_TENSOR(IntEnum):
|
||||
POSNET_ATTN_K = auto()
|
||||
POSNET_ATTN_V = auto()
|
||||
POSNET_ATTN_OUT = auto()
|
||||
SHORTCONV_CONV = auto()
|
||||
SHORTCONV_INPROJ = auto()
|
||||
SHORTCONV_OUTPROJ = auto()
|
||||
# vision
|
||||
V_MMPROJ = auto()
|
||||
V_MMPROJ_FC = auto()
|
||||
@@ -616,6 +633,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.PHI3: "phi3",
|
||||
MODEL_ARCH.PHIMOE: "phimoe",
|
||||
MODEL_ARCH.PLAMO: "plamo",
|
||||
MODEL_ARCH.PLAMO2: "plamo2",
|
||||
MODEL_ARCH.CODESHELL: "codeshell",
|
||||
MODEL_ARCH.ORION: "orion",
|
||||
MODEL_ARCH.INTERNLM2: "internlm2",
|
||||
@@ -632,6 +650,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.ARWKV7: "arwkv7",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.MAMBA2: "mamba2",
|
||||
MODEL_ARCH.JAMBA: "jamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
MODEL_ARCH.COHERE2: "cohere2",
|
||||
@@ -653,6 +672,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.EXAONE: "exaone",
|
||||
MODEL_ARCH.GRANITE: "granite",
|
||||
MODEL_ARCH.GRANITE_MOE: "granitemoe",
|
||||
MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
|
||||
MODEL_ARCH.CHAMELEON: "chameleon",
|
||||
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
|
||||
MODEL_ARCH.PLM: "plm",
|
||||
@@ -660,6 +680,11 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||
MODEL_ARCH.FALCON_H1: "falcon-h1",
|
||||
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
|
||||
MODEL_ARCH.SMOLLM3: "smollm3",
|
||||
MODEL_ARCH.LFM2: "lfm2",
|
||||
MODEL_ARCH.DREAM: "dream",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -732,7 +757,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
||||
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
|
||||
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
|
||||
MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
|
||||
MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
@@ -828,6 +856,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
|
||||
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
|
||||
MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv",
|
||||
MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj",
|
||||
MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj",
|
||||
# vision
|
||||
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
|
||||
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
|
||||
@@ -1260,6 +1291,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.DREAM: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN2VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@@ -1342,6 +1388,36 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.PLAMO2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_X,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_D,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
MODEL_TENSOR.SSM_DT_NORM,
|
||||
MODEL_TENSOR.SSM_B_NORM,
|
||||
MODEL_TENSOR.SSM_C_NORM,
|
||||
],
|
||||
MODEL_ARCH.GPT2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.POS_EMBD,
|
||||
@@ -1732,6 +1808,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
],
|
||||
MODEL_ARCH.JAMBA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_X,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_DT_NORM,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_B_NORM,
|
||||
MODEL_TENSOR.SSM_C_NORM,
|
||||
MODEL_TENSOR.SSM_D,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.XVERSE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@@ -2101,6 +2205,36 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.GRANITE_HYBRID: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_D,
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
# MoE
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
# Dense
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.CHAMELEON: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@@ -2211,6 +2345,95 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.FALCON_H1: [
|
||||
# Token embedding
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
||||
# Input layernorm
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
|
||||
# Attention components
|
||||
MODEL_TENSOR.ATTN_Q, # Query projection
|
||||
MODEL_TENSOR.ATTN_K, # Key projection
|
||||
MODEL_TENSOR.ATTN_V, # Value projection
|
||||
MODEL_TENSOR.ATTN_OUT, # Output projection
|
||||
|
||||
# SSM components (Mamba2 specific)
|
||||
MODEL_TENSOR.SSM_IN, # Input projection for SSM
|
||||
MODEL_TENSOR.SSM_CONV1D, # Convolution layer
|
||||
MODEL_TENSOR.SSM_DT, # Delta time projection
|
||||
MODEL_TENSOR.SSM_A, # A parameter (log form)
|
||||
MODEL_TENSOR.SSM_D, # D parameter
|
||||
MODEL_TENSOR.SSM_NORM, # Normalization in SSM
|
||||
MODEL_TENSOR.SSM_OUT, # Output projection
|
||||
|
||||
# Pre-feedforward layernorm
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
|
||||
# Feed-forward network components
|
||||
MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
|
||||
MODEL_TENSOR.FFN_DOWN, # Down projection
|
||||
MODEL_TENSOR.FFN_UP, # Up projection
|
||||
|
||||
# Post-feedforward layernorm
|
||||
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
|
||||
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
|
||||
],
|
||||
MODEL_ARCH.HUNYUAN_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.SMOLLM3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.LFM2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.SHORTCONV_CONV,
|
||||
MODEL_TENSOR.SHORTCONV_INPROJ,
|
||||
MODEL_TENSOR.SHORTCONV_OUTPROJ,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM, # operator_norm
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
||||
@@ -648,6 +648,9 @@ class GGUFWriter:
|
||||
def add_convnext_block_count(self, length: int) -> None:
|
||||
self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)
|
||||
|
||||
def add_shortconv_l_cache(self, length: int) -> None:
|
||||
self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length)
|
||||
|
||||
def add_block_count(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
|
||||
|
||||
|
||||
@@ -234,6 +234,8 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
|
||||
markdown_content += '## Key Value Metadata Store\n\n'
|
||||
markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n'
|
||||
markdown_content += '\n'
|
||||
total_model_bytes = 0
|
||||
total_model_elements = 0
|
||||
|
||||
kv_dump_table: list[dict[str, str | int]] = []
|
||||
for n, field in enumerate(reader.fields.values(), 1):
|
||||
@@ -377,6 +379,8 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
|
||||
tensors = tensor_groups[group]
|
||||
group_elements = sum(tensor.n_elements for tensor in tensors)
|
||||
group_percentage = group_elements / total_elements * 100
|
||||
total_group_bytes = 0
|
||||
total_group_elements = 0
|
||||
markdown_content += f"### <a name=\"{group.replace('.', '_')}\">{translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements</a>\n\n"
|
||||
|
||||
# Precalculate column sizing for visual consistency
|
||||
@@ -397,7 +401,13 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
|
||||
element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})"
|
||||
element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}"
|
||||
type_name_string = f"{tensor.tensor_type.name}"
|
||||
tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string})
|
||||
if tensor.n_elements > 0:
|
||||
bpw = (tensor.n_bytes * 8) / tensor.n_elements
|
||||
else:
|
||||
bpw = float('nan')
|
||||
tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string, "bpw": f"{bpw:.4f}"})
|
||||
total_group_bytes += tensor.n_bytes
|
||||
total_group_elements += tensor.n_elements
|
||||
|
||||
tensor_dump_table_header_map = [
|
||||
{'key_name':'t_id', 'header_name':'T_ID', 'align':'right'},
|
||||
@@ -406,6 +416,7 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
|
||||
{'key_name':'element_count', 'header_name':'Elements', 'align':'left'},
|
||||
{'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'},
|
||||
{'key_name':'tensor_type', 'header_name':'Type', 'align':'left'},
|
||||
{'key_name':'bpw', 'header_name':'BPW', 'align':'right'},
|
||||
]
|
||||
|
||||
markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table)
|
||||
@@ -413,8 +424,20 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None
|
||||
markdown_content += "\n"
|
||||
markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n"
|
||||
markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n"
|
||||
if total_group_elements > 0:
|
||||
total_group_bpw = (total_group_bytes * 8) / total_group_elements
|
||||
markdown_content += f"- Bits per Weight (BPW) for {group}: {total_group_bpw:.4f} bits\n"
|
||||
else:
|
||||
markdown_content += f"- Bits per Weight (BPW) for {group}: undefined (no elements)\n"
|
||||
markdown_content += "\n\n"
|
||||
total_model_bytes += total_group_bytes
|
||||
total_model_elements += total_group_elements
|
||||
|
||||
if total_model_elements > 0:
|
||||
total_model_bpw = (total_model_bytes * 8) / total_model_elements
|
||||
markdown_content += f"Total BPW for {os.path.basename(args.model)}: {total_model_bpw:.4f} bits"
|
||||
else:
|
||||
markdown_content += f"Total BPW for {os.path.basename(args.model)}: undefined (no elements)"
|
||||
print(markdown_content) # noqa: NP100
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ class TensorNameMap:
|
||||
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
|
||||
"transformer.word_embeddings", # falcon
|
||||
"word_embeddings", # bloom
|
||||
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
|
||||
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert nomic-bert
|
||||
"language_model.embedding.word_embeddings", # persimmon
|
||||
@@ -50,6 +50,7 @@ class TensorNameMap:
|
||||
"model.pre_ln", # rwkv7
|
||||
"model.layers.0.pre_norm", # rwkv7
|
||||
"backbone.norm", # wavtokenizer
|
||||
"model.embedding_norm", # lfm2
|
||||
),
|
||||
|
||||
# Position embeddings
|
||||
@@ -62,7 +63,7 @@ class TensorNameMap:
|
||||
# Output
|
||||
MODEL_TENSOR.OUTPUT: (
|
||||
"embed_out", # gptneox
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe plamo2
|
||||
"output", # llama-pth bloom internlm2
|
||||
"word_embeddings_for_head", # persimmon
|
||||
"lm_head.linear", # phi2
|
||||
@@ -76,7 +77,7 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.OUTPUT_NORM: (
|
||||
"gpt_neox.final_layer_norm", # gptneox
|
||||
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
|
||||
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
|
||||
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe plamo2
|
||||
"norm", # llama-pth
|
||||
"transformer.norm_f", # mpt dbrx
|
||||
"ln_f", # refact bloom qwen gpt2
|
||||
@@ -118,13 +119,14 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"h.{bid}.input_layernorm", # bloom
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe granite-hybrid
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln1", # yi
|
||||
"h.{bid}.ln_1", # gpt2
|
||||
"transformer.h.{bid}.ln", # phi2
|
||||
"model.layers.layers.{bid}.norm", # plamo
|
||||
"model.layers.layers.{bid}.pre_mixer_norm", # plamo2
|
||||
"model.layers.{bid}.attention_norm", # internlm2
|
||||
"model.layers.{bid}.norm", # mamba-qbert
|
||||
"backbone.layers.{bid}.norm", # mamba
|
||||
@@ -136,6 +138,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
"model.layers.{bid}.input_layernorm", # llama4
|
||||
"transformer_encoder.{bid}.attention_norm", # neobert
|
||||
"model.layers.{bid}.operator_norm", # lfm2
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@@ -161,6 +164,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
"encoder.layers.{bid}.mixer.Wqkv", # jina
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"model.layers.layers.{bid}.mixer.qkv_proj", # plamo2
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
"transformer_encoder.{bid}.qkv", # neobert
|
||||
@@ -220,6 +224,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||
"h.{bid}.self_attention.dense", # bloom
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"model.layers.{bid}.self_attn.out_proj", # lfm2
|
||||
"model.layers.{bid}.self_attn.linear_attn", # deci
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
@@ -230,6 +235,7 @@ class TensorNameMap:
|
||||
"h.{bid}.attn.c_proj", # gpt2
|
||||
"transformer.h.{bid}.mixer.out_proj", # phi2
|
||||
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||
"model.layers.layers.{bid}.mixer.o_proj", # plamo2
|
||||
"model.layers.{bid}.attention.wo", # internlm2
|
||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||
"encoder.layers.{bid}.mixer.out_proj", # jina
|
||||
@@ -252,8 +258,9 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_POST_NORM: (
|
||||
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
|
||||
"model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
|
||||
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
|
||||
"model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
|
||||
"model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2
|
||||
),
|
||||
|
||||
# Rotary embeddings
|
||||
@@ -279,19 +286,25 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
"model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid
|
||||
"model.layers.{bid}.pre_moe_layernorm", # mini-jamba
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"transformer_encoder.{bid}.ffn_norm", # neobert
|
||||
"model.layers.layers.{bid}.pre_mlp_norm", # plamo2
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_PRE_NORM: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
"model.layers.{bid}.pre_ff_layernorm.weight",
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
|
||||
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
|
||||
"model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
|
||||
"model.layers.{bid}.feed_forward.up_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
@@ -301,8 +314,9 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
"model.layers.{bid}.feed_forward.router", # llama4
|
||||
"model.layers.{bid}.feed_forward.router", # llama4 jamba
|
||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@@ -334,6 +348,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.fc1", # phi2
|
||||
"model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414
|
||||
"model.layers.layers.{bid}.mlp.up_proj", # plamo
|
||||
"model.layers.layers.{bid}.mlp.gate_up_proj", # plamo2
|
||||
"model.layers.{bid}.feed_forward.w3", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
|
||||
"encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe
|
||||
@@ -344,7 +359,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid
|
||||
"transformer_encoder.{bid}.ffn.w12", # neobert
|
||||
),
|
||||
|
||||
@@ -362,6 +377,8 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj",
|
||||
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
@@ -382,7 +399,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
@@ -398,6 +415,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
@@ -427,7 +445,7 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid
|
||||
"transformer_encoder.{bid}.ffn.w3", # neobert
|
||||
),
|
||||
|
||||
@@ -447,24 +465,29 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
|
||||
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
|
||||
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
|
||||
"transformer.layers.{bid}.attn.q_norm", # openelm
|
||||
"model.layers.layers.{bid}.mixer.q", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
|
||||
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
|
||||
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
|
||||
"transformer.layers.{bid}.attn.k_norm", # openelm
|
||||
"model.layers.layers.{bid}.mixer.k", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
@@ -545,42 +568,77 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_IN: (
|
||||
"model.layers.{bid}.in_proj",
|
||||
"backbone.layers.{bid}.mixer.in_proj",
|
||||
"model.layers.{bid}.in_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.in_proj", # mamba
|
||||
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_CONV1D: (
|
||||
"model.layers.{bid}.conv1d",
|
||||
"backbone.layers.{bid}.mixer.conv1d",
|
||||
"model.layers.{bid}.conv1d", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.conv1d", # mamba
|
||||
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.conv1d", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_X: (
|
||||
"model.layers.{bid}.x_proj",
|
||||
"backbone.layers.{bid}.mixer.x_proj",
|
||||
"model.layers.{bid}.x_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.x_proj", # mamba
|
||||
"model.layers.{bid}.mamba.x_proj", # jamba
|
||||
"model.layers.layers.{bid}.mixer.bcdt_proj", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_DT: (
|
||||
"model.layers.{bid}.dt_proj",
|
||||
"backbone.layers.{bid}.mixer.dt_proj",
|
||||
"model.layers.{bid}.dt_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.dt_proj", # mamba
|
||||
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_DT_NORM: (
|
||||
"model.layers.{bid}.mamba.dt_layernorm", # jamba
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_A: (
|
||||
"model.layers.{bid}.A_log",
|
||||
"backbone.layers.{bid}.mixer.A_log",
|
||||
"model.layers.{bid}.A_log", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.A_log", # mamba
|
||||
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.A_log", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_B_NORM: (
|
||||
"model.layers.{bid}.mamba.b_layernorm", # jamba
|
||||
"model.layers.{bid}.mamba.B_layernorm", # mini-jamba
|
||||
"model.layers.layers.{bid}.mixer.B_norm.weight", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_C_NORM: (
|
||||
"model.layers.{bid}.mamba.c_layernorm", # jamba
|
||||
"model.layers.{bid}.mamba.C_layernorm", # mini-jamba
|
||||
"model.layers.layers.{bid}.mixer.C_norm.weight", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_D: (
|
||||
"model.layers.{bid}.D",
|
||||
"backbone.layers.{bid}.mixer.D",
|
||||
"model.layers.{bid}.D", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.D", # mamba
|
||||
"model.layers.{bid}.mamba.D", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.D", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_DT_NORM: (
|
||||
"model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_NORM: (
|
||||
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_OUT: (
|
||||
"model.layers.{bid}.out_proj",
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
"model.layers.{bid}.out_proj", # mamba-hf
|
||||
"backbone.layers.{bid}.mixer.out_proj", # mamba
|
||||
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
|
||||
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
@@ -982,6 +1040,18 @@ class TensorNameMap:
|
||||
"backbone.posnet.{bid}.proj_out", # wavtokenizer
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SHORTCONV_CONV: (
|
||||
"model.layers.{bid}.conv.conv",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SHORTCONV_INPROJ: (
|
||||
"model.layers.{bid}.conv.in_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SHORTCONV_OUTPROJ: (
|
||||
"model.layers.{bid}.conv.out_proj",
|
||||
),
|
||||
|
||||
#############################################################################
|
||||
## Vision encoder
|
||||
|
||||
|
||||
+12
-47
@@ -71,52 +71,13 @@ extern "C" {
|
||||
typedef int32_t llama_seq_id;
|
||||
|
||||
enum llama_vocab_type {
|
||||
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
||||
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
||||
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
|
||||
};
|
||||
|
||||
// pre-tokenization types
|
||||
enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
|
||||
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
|
||||
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
|
||||
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
|
||||
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
|
||||
LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
|
||||
LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
|
||||
LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
|
||||
LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
|
||||
LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
|
||||
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
|
||||
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
||||
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
|
||||
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
|
||||
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
|
||||
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
|
||||
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
|
||||
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
|
||||
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
|
||||
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
||||
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
||||
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
||||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
|
||||
LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
|
||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
||||
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
|
||||
LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
@@ -374,6 +335,9 @@ extern "C" {
|
||||
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
|
||||
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
|
||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
@@ -764,7 +728,7 @@ extern "C" {
|
||||
// - lazily on next llama_decode()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
DEPRECATED(void llama_kv_self_seq_div(
|
||||
DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
@@ -1044,6 +1008,7 @@ extern "C" {
|
||||
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
|
||||
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
|
||||
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
|
||||
LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
|
||||
|
||||
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
{%- if not add_generation_prompt is defined -%}
|
||||
{%- set add_generation_prompt = true -%}
|
||||
{%- endif -%}
|
||||
{%- set ns = namespace(system_prompt='') -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{%- set ns.system_prompt = message['content'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{bos_token}}
|
||||
{%- if ns.system_prompt != '' -%}
|
||||
{{- 'System: ' + ns.system_prompt + '\n\n' -}}
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- 'User: ' + message['content']|trim + '\n\n' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none -%}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content -%}
|
||||
{%- set content = content.split('</think>')[-1] -%}
|
||||
{%- endif -%}
|
||||
{{- 'Assistant: ' + content|trim + '\n\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- 'Assistant:' -}}
|
||||
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||
{{- ' <think>\n</think>' }}
|
||||
{%- endif %}
|
||||
{%- if enable_thinking is defined and enable_thinking is true %}
|
||||
{{- ' <think>' }}
|
||||
{%- endif %}
|
||||
{%- endif -%}
|
||||
@@ -0,0 +1,43 @@
|
||||
{%- if tools -%}
|
||||
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if loop.first and messages[0]['role'] != 'system' -%}
|
||||
<|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|>
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
<|im_system|>system<|im_middle|>
|
||||
{%- elif message['role'] == 'user' -%}
|
||||
<|im_user|>user<|im_middle|>
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
<|im_system|>tool<|im_middle|>
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
|
||||
{%- if message['content'] -%}{{ message['content'] }}{%- endif -%}
|
||||
<|tool_calls_section_begin|>
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set func_name = tool_call['function']['name'] -%}
|
||||
{%- set formatted_id = 'functions.' + func_name + ':' + loop.index0|string -%}
|
||||
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{{ tool_call['function']['arguments'] | tojson}}<|tool_call_end|>
|
||||
{%- endfor -%}
|
||||
<|tool_calls_section_end|>
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
## Return of {{ message.tool_call_id }}\n{{ message['content'] }}
|
||||
{%- elif message['content'] is string -%}
|
||||
{{ message['content'] }}
|
||||
{%- elif message['content'] is not none -%}
|
||||
{% for content in message['content'] -%}
|
||||
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
||||
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
||||
{% else -%}
|
||||
{{ content['text'] }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
<|im_end|>
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- endif -%}
|
||||
@@ -3,6 +3,7 @@
|
||||
-r ../tools/server/tests/requirements.txt
|
||||
|
||||
-r ./requirements-compare-llama-bench.txt
|
||||
-r ./requirements-server-bench.txt
|
||||
-r ./requirements-pydantic.txt
|
||||
-r ./requirements-test-tokenizer-random.txt
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
datasets~=3.2.0
|
||||
matplotlib~=3.10.0
|
||||
numpy~=1.26.4
|
||||
requests~=2.32.3
|
||||
tqdm~=4.67.1
|
||||
Executable
+196
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This script parses docs/ops/*.csv and creates the ops.md, which is a table documenting supported operations on various ggml backends.
|
||||
"""
|
||||
import csv
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class DocsGenerator:
|
||||
def __init__(self, ggml_root: str, output_filename: str = "ops.md"):
|
||||
self.ggml_root = Path(ggml_root)
|
||||
self.ops_dir = self.ggml_root / "docs" / "ops"
|
||||
self.output_filename = output_filename
|
||||
self.backend_support: dict[str, dict[str, list[bool]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
self.all_operations: set[str] = set()
|
||||
self.all_backends: set[str] = set()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_support_files(self) -> None:
|
||||
if not self.ops_dir.exists():
|
||||
self.logger.warning(f"ops directory not found: {self.ops_dir}")
|
||||
return
|
||||
|
||||
self.logger.info(f"Parsing support files from {self.ops_dir}...")
|
||||
|
||||
for support_file in self.ops_dir.glob("*.csv"):
|
||||
self.logger.info(f" Reading: {support_file.name}")
|
||||
self._parse_support_file(support_file)
|
||||
|
||||
def _parse_support_file(self, file_path: Path) -> None:
|
||||
try:
|
||||
with open(file_path, "r", newline='') as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
for row in reader:
|
||||
# Skip rows that don't have support mode
|
||||
if row.get('test_mode') != 'support':
|
||||
continue
|
||||
|
||||
backend_name = row.get('backend_name', '').strip()
|
||||
operation = row.get('op_name', '').strip()
|
||||
supported_str = row.get('error_message', '').strip() # "yes" or "no"
|
||||
backend_reg_name = row.get('backend_reg_name', '').strip()
|
||||
|
||||
# Skip invalid or error operations
|
||||
if not operation or not backend_name or operation in [
|
||||
"CONTEXT_ERROR",
|
||||
"BUILD_ERROR",
|
||||
]:
|
||||
continue
|
||||
|
||||
is_supported = supported_str.lower() == "yes"
|
||||
|
||||
# Use backend_reg_name for grouping, fallback to backend_name
|
||||
backend_key = backend_reg_name if backend_reg_name else backend_name
|
||||
|
||||
self.all_backends.add(backend_key)
|
||||
self.backend_support[backend_key][operation].append(is_supported)
|
||||
self.all_operations.add(operation)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f" Error parsing {file_path}: {e}")
|
||||
|
||||
def get_backend_support_status(self, backend: str, operation: str) -> str:
|
||||
support_list = self.backend_support[backend].get(operation, [])
|
||||
|
||||
if not support_list:
|
||||
return "unsupported"
|
||||
|
||||
all_supported = all(support_list)
|
||||
any_supported = any(support_list)
|
||||
|
||||
if all_supported:
|
||||
return "supported"
|
||||
elif any_supported:
|
||||
return "partially supported"
|
||||
else:
|
||||
return "unsupported"
|
||||
|
||||
def get_support_status(self, operation: str) -> str:
|
||||
if operation not in self.all_operations:
|
||||
return "unsupported"
|
||||
|
||||
support_count = 0
|
||||
total_backends = len(self.all_backends)
|
||||
|
||||
for backend in self.all_backends:
|
||||
if self.backend_support[backend].get(operation, False):
|
||||
support_count += 1
|
||||
|
||||
if support_count == 0:
|
||||
return "unsupported"
|
||||
elif support_count == total_backends:
|
||||
return "supported"
|
||||
else:
|
||||
return "partially supported"
|
||||
|
||||
def get_support_symbol(self, status: str) -> str:
|
||||
symbols = {"supported": "✅", "partially supported": "🟡", "unsupported": "❌"}
|
||||
return symbols.get(status, "❓")
|
||||
|
||||
def generate_markdown(self) -> str:
|
||||
lines = []
|
||||
|
||||
lines.append("# GGML Operations")
|
||||
lines.append("")
|
||||
lines.append("List of GGML operations and backend support status.")
|
||||
lines.append("")
|
||||
lines.append("Legend:")
|
||||
lines.append("- ✅ Fully supported by this backend")
|
||||
lines.append("- 🟡 Partially supported by this backend")
|
||||
lines.append("- ❌ Not supported by this backend")
|
||||
lines.append("")
|
||||
|
||||
backends = sorted(self.all_backends)
|
||||
header = "| Operation |"
|
||||
for backend in backends:
|
||||
header += f" {backend} |"
|
||||
|
||||
separator = "|-----------|"
|
||||
for _ in backends:
|
||||
separator += "------|"
|
||||
|
||||
lines.append(header)
|
||||
lines.append(separator)
|
||||
|
||||
sorted_operations = sorted(self.all_operations)
|
||||
|
||||
for operation in sorted_operations:
|
||||
row = f"| {operation:>32} |"
|
||||
|
||||
for backend in backends:
|
||||
status = self.get_backend_support_status(backend, operation)
|
||||
if status == "supported":
|
||||
symbol = "✅"
|
||||
elif status == "partially supported":
|
||||
symbol = "🟡"
|
||||
else:
|
||||
symbol = "❌"
|
||||
row += f" {symbol} |"
|
||||
|
||||
lines.append(row)
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def run(self) -> None:
|
||||
self.logger.info("Parsing GGML operation support files...")
|
||||
self.parse_support_files()
|
||||
|
||||
if not self.all_operations:
|
||||
self.logger.error(
|
||||
"No operations found. Make sure to run test-backend-ops support --output csv > docs/ops/file.csv first."
|
||||
)
|
||||
return
|
||||
|
||||
self.logger.info(
|
||||
f"Found {len(self.all_operations)} operations across {len(self.all_backends)} backends"
|
||||
)
|
||||
|
||||
self.logger.info("Generating markdown...")
|
||||
markdown_content = self.generate_markdown()
|
||||
|
||||
docs_dir = self.ggml_root / "docs"
|
||||
docs_dir.mkdir(exist_ok=True)
|
||||
|
||||
ops_file = docs_dir / self.output_filename
|
||||
with open(ops_file, "w") as f:
|
||||
f.write(markdown_content)
|
||||
|
||||
self.logger.info(f"Generated: {ops_file}")
|
||||
self.logger.info(f"Operations: {len(self.all_operations)}")
|
||||
self.logger.info(f"Backends: {len(self.all_backends)}")
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
output_filename = sys.argv[1]
|
||||
else:
|
||||
output_filename = "ops.md"
|
||||
|
||||
generator = DocsGenerator(".", output_filename)
|
||||
generator.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executable
+265
@@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
from time import sleep, time
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import requests
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logger = logging.getLogger("server-bench")
|
||||
|
||||
|
||||
def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
|
||||
ret = []
|
||||
if dataset_name.lower() == "mmlu":
|
||||
logger.info("Loading MMLU dataset...")
|
||||
ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore
|
||||
else:
|
||||
return None
|
||||
if n_prompts >= 0:
|
||||
ret = ret[:n_prompts]
|
||||
return ret
|
||||
|
||||
|
||||
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
|
||||
assert n_prompts >= 0
|
||||
ret: list[int] = []
|
||||
for i in range(n_prompts):
|
||||
random.seed(13 * i + 0)
|
||||
ret.append(random.randint(prompt_length_min, prompt_length_max))
|
||||
return ret
|
||||
|
||||
|
||||
def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
|
||||
return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
|
||||
|
||||
|
||||
def get_server(path_server: str, path_log: Optional[str]) -> dict:
|
||||
logger.info("Starting the llama.cpp server...")
|
||||
hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
|
||||
port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
|
||||
address: str = f"http://{hostname}:{port}"
|
||||
|
||||
fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
|
||||
process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
|
||||
|
||||
n_failures: int = 0
|
||||
while True:
|
||||
try:
|
||||
sleep(1.0)
|
||||
exit_code = process.poll()
|
||||
if exit_code is not None:
|
||||
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
|
||||
response = requests.get(f"{address}/health")
|
||||
if response.status_code == 200:
|
||||
break
|
||||
except requests.ConnectionError:
|
||||
n_failures += 1
|
||||
if n_failures >= 10:
|
||||
raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
|
||||
|
||||
return {"process": process, "address": address, "fout": fout}
|
||||
|
||||
|
||||
def get_prompt_length(data: dict) -> int:
|
||||
session = data["session"]
|
||||
server_address: str = data["server_address"]
|
||||
|
||||
response = session.post(
|
||||
f"{server_address}/apply-template",
|
||||
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
|
||||
prompt: str = json.loads(response.text)["prompt"]
|
||||
response = session.post(
|
||||
f"{server_address}/tokenize",
|
||||
json={"content": prompt, "add_special": True}
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
|
||||
tokens: list[str] = json.loads(response.text)["tokens"]
|
||||
return len(tokens)
|
||||
|
||||
|
||||
def send_prompt(data: dict) -> tuple[float, list[float]]:
|
||||
session = data["session"]
|
||||
server_address: str = data["server_address"]
|
||||
|
||||
t_submit = time()
|
||||
if data["synthetic_prompt"]:
|
||||
json_data: dict = {
|
||||
"prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
|
||||
"seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
|
||||
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
|
||||
else:
|
||||
response = session.post(
|
||||
f"{server_address}/apply-template",
|
||||
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
|
||||
prompt: str = json.loads(response.text)["prompt"]
|
||||
|
||||
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
|
||||
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
|
||||
|
||||
token_arrival_times: list[float] = []
|
||||
for line in response.iter_lines(decode_unicode=False):
|
||||
if not line.startswith(b"data: "):
|
||||
continue
|
||||
token_arrival_times.append(time())
|
||||
token_arrival_times = token_arrival_times[:-1]
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
|
||||
|
||||
return (t_submit, token_arrival_times)
|
||||
|
||||
|
||||
def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
|
||||
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
|
||||
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
|
||||
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
|
||||
if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
|
||||
logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
|
||||
os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
|
||||
if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
|
||||
logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
|
||||
os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
|
||||
|
||||
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
|
||||
prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
|
||||
synthetic_prompts: bool = prompts is None
|
||||
prompt_n = []
|
||||
|
||||
if synthetic_prompts:
|
||||
prompt_source_split: list[str] = prompt_source.split("-")
|
||||
assert len(prompt_source_split) == 3
|
||||
assert prompt_source_split[0].lower() == "rng"
|
||||
prompt_length_min: int = int(prompt_source_split[1])
|
||||
prompt_length_max: int = int(prompt_source_split[2])
|
||||
logger.info("Generating random prompts...")
|
||||
prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
|
||||
prompts = get_prompts_rng(prompt_n)
|
||||
else:
|
||||
n_predict_min = n_predict
|
||||
|
||||
if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
|
||||
context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
|
||||
context_total: int = context_per_slot * parallel
|
||||
os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
|
||||
logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
|
||||
|
||||
server: Optional[dict] = None
|
||||
session = None
|
||||
try:
|
||||
server = get_server(path_server, path_log)
|
||||
server_address: str = server["address"]
|
||||
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
|
||||
session = requests.Session()
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
data: list[dict] = []
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
random.seed(13 * i + 1)
|
||||
data.append({
|
||||
"session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
|
||||
"n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
|
||||
|
||||
if not synthetic_prompts:
|
||||
logger.info("Getting the prompt lengths...")
|
||||
prompt_n = [get_prompt_length(d) for d in data]
|
||||
|
||||
logger.info("Starting the benchmark...\n")
|
||||
t0 = time()
|
||||
results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
|
||||
finally:
|
||||
if server is not None:
|
||||
server["process"].terminate()
|
||||
server["process"].wait()
|
||||
if session is not None:
|
||||
session.close()
|
||||
|
||||
prompt_t = []
|
||||
token_t = []
|
||||
depth_sum: int = 0
|
||||
for pn, (t_submit, tat) in zip(prompt_n, results):
|
||||
prompt_t.append(tat[0] - t_submit)
|
||||
token_t += tat
|
||||
n_tokens: int = len(tat)
|
||||
depth_sum += n_tokens * pn
|
||||
depth_sum += n_tokens * (n_tokens + 1) // 2
|
||||
assert len(token_t) > 0
|
||||
prompt_n = np.array(prompt_n, dtype=np.int64)
|
||||
prompt_t = np.array(prompt_t, dtype=np.float64)
|
||||
token_t = np.array(token_t, dtype=np.float64)
|
||||
|
||||
token_t -= t0
|
||||
token_t_last = np.max(token_t)
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"Benchmark duration: {token_t_last:.2f} s")
|
||||
logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
|
||||
logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
|
||||
logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
|
||||
logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
|
||||
logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
|
||||
logger.info(f"Total generated tokens: {token_t.shape[0]}")
|
||||
logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
|
||||
logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
|
||||
logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
|
||||
logger.info("")
|
||||
logger.info(
|
||||
"The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
|
||||
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
|
||||
|
||||
plt.figure()
|
||||
plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
|
||||
plt.xlim(0, 1.05e0 * np.max(prompt_n))
|
||||
plt.ylim(0, 1.05e3 * np.max(prompt_t))
|
||||
plt.xlabel("Prompt length [tokens]")
|
||||
plt.ylabel("Time to first token [ms]")
|
||||
plt.savefig("prompt_time.png", dpi=240)
|
||||
|
||||
bin_max = np.ceil(token_t_last) + 1
|
||||
plt.figure()
|
||||
plt.hist(token_t, np.arange(0, bin_max))
|
||||
plt.xlim(0, bin_max + 1)
|
||||
plt.xlabel("Time [s]")
|
||||
plt.ylabel("Num. tokens generated per second")
|
||||
plt.savefig("gen_rate.png", dpi=240)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
|
||||
"Results are printed to console and visualized as plots (saved to current working directory). "
|
||||
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
|
||||
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
|
||||
parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
|
||||
parser.add_argument(
|
||||
"--prompt_source", type=str, default="rng-1024-2048",
|
||||
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
|
||||
"rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
|
||||
parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
|
||||
parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
|
||||
parser.add_argument(
|
||||
"--n_predict_min", type=int, default=1024,
|
||||
help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
|
||||
args = parser.parse_args()
|
||||
benchmark(**vars(args))
|
||||
@@ -1 +1 @@
|
||||
0405219965324e11a29b6aadfe22a6d66131978f
|
||||
d62df60a07ba3deeb85e5cfc9b1ee07645ff35e2
|
||||
|
||||
+231
-3
@@ -34,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_PHI3, "phi3" },
|
||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
{ LLM_ARCH_PLAMO2, "plamo2" },
|
||||
{ LLM_ARCH_CODESHELL, "codeshell" },
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
{ LLM_ARCH_INTERNLM2, "internlm2" },
|
||||
@@ -46,6 +47,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_MAMBA2, "mamba2" },
|
||||
{ LLM_ARCH_JAMBA, "jamba" },
|
||||
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
{ LLM_ARCH_COMMAND_R, "command-r" },
|
||||
{ LLM_ARCH_COHERE2, "cohere2" },
|
||||
@@ -71,6 +74,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_ARWKV7, "arwkv7" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
@@ -78,6 +82,10 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_LFM2, "lfm2" },
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -150,7 +158,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
{ LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@@ -184,6 +191,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
|
||||
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
|
||||
|
||||
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
|
||||
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
@@ -777,6 +786,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PLAMO2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
|
||||
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
|
||||
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_CODESHELL,
|
||||
{
|
||||
@@ -1022,6 +1061,61 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JAMBA,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
|
||||
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_FALCON_H1,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_XVERSE,
|
||||
{
|
||||
@@ -1582,6 +1676,43 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE_HYBRID,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
// mamba(2) ssm layers
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
// attention layers
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
// dense FFN
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
// moe FFN
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
// shared expert
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_CHAMELEON,
|
||||
{
|
||||
@@ -1694,12 +1825,90 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SMOLLM3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LFM2,
|
||||
{
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
|
||||
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
|
||||
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DREAM,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
@@ -1778,6 +1987,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
|
||||
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
|
||||
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
|
||||
{LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
@@ -1858,6 +2070,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
|
||||
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
};
|
||||
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
@@ -1925,9 +2140,22 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
|
||||
}
|
||||
|
||||
bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||
// TODO: There are currently no hybrid models! Once there are, this will be
|
||||
// the place to identify them
|
||||
switch (arch) {
|
||||
case LLM_ARCH_JAMBA:
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
case LLM_ARCH_PLAMO2:
|
||||
case LLM_ARCH_GRANITE_HYBRID:
|
||||
case LLM_ARCH_LFM2:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_is_diffusion(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_DREAM:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
+17
-1
@@ -38,6 +38,7 @@ enum llm_arch {
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PHIMOE,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_PLAMO2,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_INTERNLM2,
|
||||
@@ -50,6 +51,8 @@ enum llm_arch {
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_MAMBA2,
|
||||
LLM_ARCH_JAMBA,
|
||||
LLM_ARCH_FALCON_H1,
|
||||
LLM_ARCH_XVERSE,
|
||||
LLM_ARCH_COMMAND_R,
|
||||
LLM_ARCH_COHERE2,
|
||||
@@ -75,6 +78,7 @@ enum llm_arch {
|
||||
LLM_ARCH_ARWKV7,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_GRANITE_HYBRID,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_PLM,
|
||||
@@ -82,6 +86,10 @@ enum llm_arch {
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_LFM2,
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -154,7 +162,6 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_LAYER_INDICES,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
@@ -223,6 +230,8 @@ enum llm_kv {
|
||||
|
||||
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
|
||||
|
||||
LLM_KV_SHORTCONV_L_CACHE,
|
||||
|
||||
// deprecated:
|
||||
LLM_KV_TOKENIZER_PREFIX_ID,
|
||||
LLM_KV_TOKENIZER_SUFFIX_ID,
|
||||
@@ -293,7 +302,10 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SSM_CONV1D,
|
||||
LLM_TENSOR_SSM_X,
|
||||
LLM_TENSOR_SSM_DT,
|
||||
LLM_TENSOR_SSM_DT_NORM,
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_B_NORM,
|
||||
LLM_TENSOR_SSM_C_NORM,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_NORM,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
@@ -389,6 +401,9 @@ enum llm_tensor {
|
||||
LLM_TENSOR_POS_NET_ATTN_K,
|
||||
LLM_TENSOR_POS_NET_ATTN_V,
|
||||
LLM_TENSOR_POS_NET_ATTN_OUT,
|
||||
LLM_TENSOR_SHORTCONV_CONV,
|
||||
LLM_TENSOR_SHORTCONV_INPROJ,
|
||||
LLM_TENSOR_SHORTCONV_OUTPROJ,
|
||||
};
|
||||
|
||||
enum llm_tensor_layer {
|
||||
@@ -465,3 +480,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
|
||||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch);
|
||||
bool llm_arch_is_hybrid (const llm_arch & arch);
|
||||
bool llm_arch_is_diffusion(const llm_arch & arch);
|
||||
|
||||
+18
-11
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
|
||||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all) {
|
||||
clear();
|
||||
|
||||
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
|
||||
// validate input batch
|
||||
//
|
||||
|
||||
if (n_seq_max > LLAMA_MAX_SEQ) {
|
||||
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (batch.token) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
||||
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
|
||||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
|
||||
|
||||
// initialize the starting position for each sequence based on the positions in the memory
|
||||
llama_pos p0[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (!memory) {
|
||||
// if no memory -> start from 0
|
||||
p0[s] = 0;
|
||||
@@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
|
||||
// compute stats
|
||||
//
|
||||
|
||||
this->n_embd = n_embd;
|
||||
this->n_embd = n_embd;
|
||||
this->n_seq_max = n_seq_max;
|
||||
|
||||
// count the outputs in this batch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
@@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
|
||||
seq_set_map[cur].push_back(i);
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
seq_idx[s] = seq_id_unq.size();
|
||||
seq_id_unq.push_back(s);
|
||||
@@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
|
||||
// consistency checks
|
||||
//
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
@@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
|
||||
}
|
||||
|
||||
if (memory) {
|
||||
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
||||
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
||||
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
|
||||
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
|
||||
if (seq_cpl[s0][s1]) {
|
||||
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
||||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
||||
@@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
|
||||
//
|
||||
{
|
||||
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_set[s].set();
|
||||
}
|
||||
|
||||
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_pos[s] = -1;
|
||||
}
|
||||
|
||||
@@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
||||
}
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
|
||||
ubatch.seq_id_unq.push_back(s);
|
||||
|
||||
@@ -48,6 +48,7 @@ public:
|
||||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all);
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
@@ -100,6 +101,7 @@ private:
|
||||
const uint32_t n_pos_per_embd;
|
||||
|
||||
uint32_t n_embd;
|
||||
uint32_t n_seq_max;
|
||||
uint32_t n_outputs;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user