forked from wylab/llama.cpp
Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3e3cb19f64 | |||
| 5acd455460 | |||
| 554fd578a5 | |||
| fa882fd2b1 | |||
| ffa059034c | |||
| 120bf7046d | |||
| 4258e0cfe7 | |||
| 7ea15bb64c | |||
| 9c7185dd28 | |||
| 1ee9d0b415 | |||
| 48e2fa9fb7 | |||
| 5b6913c47b | |||
| bc07349a7f | |||
| e60f241eac | |||
| e38b7c6e9e | |||
| 5016b72862 | |||
| 7049736b2d | |||
| 01d2bdc2bc | |||
| 56fc38b965 | |||
| 1fb9504eb7 | |||
| 3f750f8d76 | |||
| c515fc5771 | |||
| f9bc66c3eb | |||
| a31cf36ad9 | |||
| 81d54bbfd5 | |||
| c7be9febcb | |||
| 8415f61e23 | |||
| 2c301e91ab | |||
| 4b2dae383d | |||
| 41aac5c69b | |||
| a2fba89a42 | |||
| 20cc625edc | |||
| 11f0af5504 | |||
| a3cb04744f | |||
| 4a8fbe0a5e | |||
| 31d0ff1869 | |||
| 97870e6497 | |||
| 477a66b035 | |||
| e60f01d941 | |||
| 81086cd6a3 | |||
| 68ee98ae18 | |||
| cdb6da468c | |||
| 6d69ab3f26 |
@@ -387,6 +387,39 @@ jobs:
|
||||
cd build
|
||||
ctest -L main --verbose
|
||||
|
||||
ubuntu-24-cmake-vulkan-deb:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-24-cmake-vulkan-deb
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Configure
|
||||
id: cmake_configure
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGGML_VULKAN=ON
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
ubuntu-24-cmake-vulkan:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
|
||||
+161
-133
@@ -3358,7 +3358,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
add_opt(common_arg(
|
||||
{"--chat-template-kwargs"}, "STRING",
|
||||
string_format("sets additional params for the json template parser"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
[](common_params & params, const std::string & value) {
|
||||
auto parsed = json::parse(value);
|
||||
for (const auto & item : parsed.items()) {
|
||||
params.default_template_kwargs[item.key()] = item.value().dump();
|
||||
@@ -3570,21 +3570,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
common_log_set_file(common_log_main(), value.c_str());
|
||||
}
|
||||
));
|
||||
add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"--log-colors"}, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"-v", "--verbose", "--log-verbose"},
|
||||
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
|
||||
@@ -3850,7 +3852,87 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
// model-specific
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-steps"}, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-visual"},
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)", params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-eps"}, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-algorithm"}, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-alg-temp"}, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-block-length"}, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-cfg-scale"}, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-add-gumbel-noise"}, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format("adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)", (double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format("(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-decay-epochs", "--learning-rate-decay-epochs"}, "ALPHA",
|
||||
string_format("(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)", (double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-wd", "--weight-decay"}, "WD",
|
||||
string_format("adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).", (double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-val-split", "--val-split"}, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).", (double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-epochs", "--epochs"}, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-opt", "--optimizer"}, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
}
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
|
||||
// presets
|
||||
add_opt(common_arg(
|
||||
{"--tts-oute-default"},
|
||||
string_format("use default OuteTTS models (note: can download weights from the internet)"),
|
||||
@@ -3863,39 +3945,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-bge-small-en-default"},
|
||||
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
|
||||
{"--embd-gemma-default"},
|
||||
string_format("use default EmbeddingGemma model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
|
||||
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-e5-small-en-default"},
|
||||
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
|
||||
params.model.hf_file = "e5-small-v2-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-gte-small-default"},
|
||||
string_format("use default gte-small model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
|
||||
params.model.hf_file = "gte-small-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.model.hf_repo = "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF";
|
||||
params.model.hf_file = "embeddinggemma-300M-qat-Q4_0.gguf";
|
||||
params.port = 8011;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 2048;
|
||||
params.n_parallel = 32;
|
||||
params.n_ctx = 2048*params.n_parallel;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
@@ -3990,96 +4049,65 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-steps" }, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-visual" },
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
|
||||
params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-20b-default"},
|
||||
string_format("use gpt-oss-20b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-20b-GGUF";
|
||||
params.model.hf_file = "gpt-oss-20b-mxfp4.gguf";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-eps" }, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-algorithm" }, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
|
||||
params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-alg-temp" }, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-120b-default"},
|
||||
string_format("use gpt-oss-120b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-120b-GGUF";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-block-length" }, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-cfg-scale" }, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-add-gumbel-noise" }, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--vision-gemma-4b-default"},
|
||||
string_format("use Gemma 3 4B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-4b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
|
||||
add_opt(
|
||||
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format(
|
||||
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
|
||||
(double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
|
||||
(double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{ "-wd", "--weight-decay" }, "WD",
|
||||
string_format(
|
||||
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
|
||||
(double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).",
|
||||
(double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
})
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
{"--vision-gemma-12b-default"},
|
||||
string_format("use Gemma 3 12B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-12b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
if (is_arguments_path({})) {
|
||||
// Entire JSON is the arguments and was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json.dump(),
|
||||
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
@@ -444,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
|
||||
|
||||
+1
-1
@@ -426,7 +426,7 @@ struct common_params {
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||
int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc.
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -168,6 +169,47 @@ bool common_json_parse(
|
||||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
@@ -186,6 +228,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
@@ -205,6 +250,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
@@ -230,6 +278,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
|
||||
+2
-10
@@ -5966,20 +5966,12 @@ class Mamba2Model(TextModel):
|
||||
class JambaModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.JAMBA
|
||||
|
||||
def get_vocab_base_pre(self, tokenizer) -> str:
|
||||
del tokenizer # unused
|
||||
|
||||
return "gpt-2"
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.model").is_file():
|
||||
# Using Jamba's tokenizer.json causes errors on model load
|
||||
# (something about "byte not found in vocab"),
|
||||
# but there's a working tokenizer.model
|
||||
self._set_vocab_sentencepiece()
|
||||
else:
|
||||
# Some Jamba models only have a tokenizer.json, which works.
|
||||
self._set_vocab_gpt2()
|
||||
self._set_vocab_llama_hf()
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
|
||||
|
||||
+10
-8
@@ -31,7 +31,7 @@ Legend:
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -51,7 +51,7 @@ Legend:
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
@@ -65,11 +65,11 @@ Legend:
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
@@ -92,9 +92,9 @@ Legend:
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -102,9 +102,11 @@ Legend:
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
+12095
-4249
File diff suppressed because it is too large
Load Diff
@@ -145,6 +145,9 @@ endif()
|
||||
# which was introduced in POSIX.1-2008, forcing us to go higher
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
|
||||
add_compile_definitions(_XOPEN_SOURCE=700)
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
|
||||
# Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
|
||||
# in order to define _SC_PHYS_PAGES.
|
||||
else()
|
||||
add_compile_definitions(_XOPEN_SOURCE=600)
|
||||
endif()
|
||||
|
||||
+100
-107
@@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated(
|
||||
unary_op(ctx, acl_src0, acl_dst);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_dst);
|
||||
if(src1)
|
||||
ggml_cann_release_resources(ctx, acl_src1);
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -894,14 +892,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get or expand a cached float32 tensor filled with a scalar value.
|
||||
* @brief Get or expand a cached tensor filled with a scalar value.
|
||||
*
|
||||
* This function manages cached device memory for float32 tensors. If the current
|
||||
* This function manages cached device memory for tensors. If the current
|
||||
* cache size is insufficient for the requested tensor shape, the old memory will
|
||||
* be released and new memory will be allocated. The allocated buffer is then
|
||||
* initialized either with zeros (when @p value == 0.0f) or with the given scalar
|
||||
* value using CANN operations. Finally, an aclTensor object is created from the
|
||||
* cached memory and returned.
|
||||
* be released and new memory will be allocated. The allocated buffer is
|
||||
* initialized with the given scalar value using CANN operations.
|
||||
* Finally, an aclTensor object is created from the cached memory and returned.
|
||||
*
|
||||
* @param ctx The CANN backend context that manages device memory.
|
||||
* @param buffer A pointer to the cached device buffer (will be allocated
|
||||
@@ -910,17 +907,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
* updated when the cache is expanded.
|
||||
* @param ne The tensor shape array (number of elements in each dimension).
|
||||
* @param nb The stride size for each dimension.
|
||||
* @param dtype Data type of cached tensor.
|
||||
* @param dims The number of tensor dimensions.
|
||||
* @param value The scalar value used to fill the tensor (supports zero
|
||||
* initialization via memset or arbitrary values via fill_scalar).
|
||||
* @return An aclTensor pointer created from the cached buffer.
|
||||
*/
|
||||
static aclTensor* get_f32_cache_acl_tensor(
|
||||
static aclTensor* get_cache_acl_tensor(
|
||||
ggml_backend_cann_context& ctx,
|
||||
void** buffer,
|
||||
int64_t &cache_element,
|
||||
int64_t* ne,
|
||||
size_t* nb,
|
||||
ggml_type dtype,
|
||||
int64_t dims,
|
||||
float value) {
|
||||
// Calculate total number of elements
|
||||
@@ -928,7 +927,7 @@ static aclTensor* get_f32_cache_acl_tensor(
|
||||
for (int i = 0; i < dims; i++) {
|
||||
n_element *= ne[i];
|
||||
}
|
||||
size_t size = n_element * sizeof(float);
|
||||
size_t size = n_element * ggml_type_size(dtype);
|
||||
|
||||
// Allocate or expand cache if needed
|
||||
if (cache_element < n_element) {
|
||||
@@ -941,19 +940,17 @@ static aclTensor* get_f32_cache_acl_tensor(
|
||||
cache_element = n_element;
|
||||
|
||||
// Initialize cache
|
||||
if (value == 0.0f) {
|
||||
ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
|
||||
} else {
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { sizeof(float) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, 1, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { ggml_type_size(dtype) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
|
||||
pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, value, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
|
||||
return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
|
||||
return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
|
||||
ggml_type_size(dtype), ne, nb, dims);
|
||||
}
|
||||
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
@@ -965,35 +962,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
// build gamma, one...
|
||||
// build gamma.
|
||||
size_t acl_gamma_nb[GGML_MAX_DIMS];
|
||||
acl_gamma_nb[0] = sizeof(float);
|
||||
// gamma's type is the same with dst.
|
||||
acl_gamma_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_gamma = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_one_tensor_cache.cache,
|
||||
ctx.rms_norm_one_tensor_cache.size,
|
||||
src->ne,
|
||||
acl_gamma_nb,
|
||||
dst->type,
|
||||
1, // dims
|
||||
1.0f // value
|
||||
);
|
||||
|
||||
// build rstd, zero...
|
||||
// build rstd.
|
||||
int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
|
||||
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
|
||||
// rstd will always be F32.
|
||||
acl_rstd_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
|
||||
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_rstd = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_zero_tensor_cache.cache,
|
||||
ctx.rms_norm_zero_tensor_cache.size,
|
||||
acl_rstd_ne,
|
||||
acl_rstd_nb,
|
||||
GGML_TYPE_F32,
|
||||
GGML_MAX_DIMS - 1,
|
||||
0.0f // value
|
||||
);
|
||||
@@ -1765,33 +1766,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src0 = dst->src[0]; // src
|
||||
ggml_tensor* src1 = dst->src[1]; // index
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_F16: {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
if(src0->type == dst->type) {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
} else {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = dst->nb[0];
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
// add 1 dim for bcast mul.
|
||||
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
|
||||
@@ -1799,7 +1802,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
|
||||
*dequant_ne;
|
||||
int64_t scale_offset = 0;
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
weight_ne[0] = QK8_0;
|
||||
weight_ne[1] = src0->ne[0] / QK8_0;
|
||||
@@ -1809,7 +1811,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
weight_ne[i] = src0->ne[i - 1];
|
||||
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,1]
|
||||
scale_ne[0] = 1;
|
||||
scale_ne[1] = src0->ne[0] / QK8_0;
|
||||
@@ -1819,18 +1820,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
scale_ne[i] = src0->ne[i - 1];
|
||||
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
dequant_ne = weight_ne;
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
|
||||
}
|
||||
|
||||
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
|
||||
ggml_cann_pool_alloc dequant_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
|
||||
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
|
||||
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
|
||||
GGML_MAX_DIMS + 1);
|
||||
@@ -1838,22 +1836,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
|
||||
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
|
||||
aclTensor* dequant_tensor = ggml_cann_create_tensor(
|
||||
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
|
||||
dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
|
||||
|
||||
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
dequant_ne = src0->ne;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
|
||||
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
|
||||
dequant_ne, dequant_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
|
||||
ggml_cann_release_resources(ctx, dequant_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -1965,16 +1961,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
|
||||
// Only check env once.
|
||||
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (weight_to_nz && is_matmul_weight(weight)) {
|
||||
int64_t acl_stride[2] = {1, transpose_ne[1]};
|
||||
|
||||
// Reverse ne.
|
||||
std::reverse(transpose_ne, transpose_ne + n_dims);
|
||||
|
||||
std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
|
||||
|
||||
acl_weight_tensor = aclCreateTensor(
|
||||
transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
|
||||
0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
|
||||
} else {
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
|
||||
@@ -3178,7 +3166,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
aclTensor* acl_src0_f16_tensor = nullptr;
|
||||
aclTensor* acl_src1_f16_tensor = nullptr;
|
||||
aclTensor* acl_src2_f16_tensor = nullptr;
|
||||
aclTensor* acl_dst_f16_tensor = nullptr;
|
||||
|
||||
// Step 1: cast the src0 (Query) to fp16 if needed
|
||||
ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
|
||||
@@ -3216,22 +3203,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
|
||||
src2_bsnd_nb, GGML_MAX_DIMS);
|
||||
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
acl_dst_f16_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
|
||||
// Step 3: create the PSEShift tensor if needed
|
||||
// this tensor is considered as mask (f16) in the llama.cpp
|
||||
aclTensor* bcast_pse_tensor = nullptr;
|
||||
@@ -3317,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
aclTensor* acl_q_tensor = acl_src0_f16_tensor;
|
||||
aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
|
||||
aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
|
||||
auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
|
||||
auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
|
||||
aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
|
||||
aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
|
||||
|
||||
int64_t numHeads = src0->ne[2]; // N
|
||||
int64_t numKeyValueHeads = src1->ne[2];
|
||||
@@ -3334,8 +3305,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
int64_t keyAntiquantMode = 0;
|
||||
int64_t valueAntiquantMode = 0;
|
||||
|
||||
// Step 5: launch the FusedInferAttentionScoreV2 kernel.
|
||||
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
aclTensor * fa_dst_tensor = nullptr;
|
||||
aclTensor * acl_dst_tensor = nullptr;
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
fa_dst_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
}
|
||||
else {
|
||||
fa_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
}
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
|
||||
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
|
||||
@@ -3357,23 +3349,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
blockSize, antiquantMode, // blockSize, antiquantMode
|
||||
softmaxLseFlag, // softmaxLseFlag
|
||||
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
|
||||
acl_dst_f16_tensor, // attentionOut
|
||||
fa_dst_tensor, // attentionOut
|
||||
nullptr // softmaxLse
|
||||
);
|
||||
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
// TODO: when dst is fp16, don't need cast
|
||||
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_src1_f16_tensor,
|
||||
acl_src2_f16_tensor,
|
||||
acl_dst_f16_tensor,
|
||||
acl_dst_tensor);
|
||||
if(src3 != nullptr){
|
||||
ggml_cann_release_resources(ctx, bcast_pse_tensor);
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
}
|
||||
}else{
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_k_tensor_list,
|
||||
acl_v_tensor_list,
|
||||
fa_dst_tensor,
|
||||
acl_dst_tensor,
|
||||
bcast_pse_tensor);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("Function is not implemented.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ struct ggml_compute_params {
|
||||
#endif // __VXE2__
|
||||
#endif // __s390x__ && __VEC__
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
|
||||
#include <sys/prctl.h>
|
||||
#endif
|
||||
|
||||
|
||||
@@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
|
||||
#endif
|
||||
|
||||
static void ggml_init_arm_arch_features(void) {
|
||||
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__linux__)
|
||||
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
|
||||
#else
|
||||
// TODO: add support of SVE for non-linux systems
|
||||
#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - mean;
|
||||
y[i] = val;
|
||||
val *= val;
|
||||
sum += (ggml_float)val;
|
||||
y[i] = val;
|
||||
}
|
||||
return sum/n;
|
||||
}
|
||||
|
||||
@@ -144,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
|
||||
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
|
||||
@@ -160,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
|
||||
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
@@ -820,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
|
||||
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
|
||||
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x[i]);
|
||||
y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
|
||||
}
|
||||
}
|
||||
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
|
||||
|
||||
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||
|
||||
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
|
||||
return GGML_CUDA_CC_IS_AMD(cc) ||
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
|
||||
}
|
||||
|
||||
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
|
||||
// Important: do not use this function if dst and src both point at registers.
|
||||
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
|
||||
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
|
||||
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
||||
if constexpr (alignment != 0) {
|
||||
@@ -939,13 +944,6 @@ struct ggml_cuda_graph {
|
||||
bool disable_due_to_failed_graph_capture = false;
|
||||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
bool use_cpy_indirection = false;
|
||||
std::vector<char *> cpy_dest_ptrs;
|
||||
char ** dest_ptrs_d;
|
||||
int dest_ptrs_size = 0;
|
||||
// Index to allow each cpy kernel to be aware of it's position within the graph
|
||||
// relative to other cpy nodes.
|
||||
int graph_cpynode_index = -1;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
+55
-163
@@ -8,18 +8,16 @@
|
||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
@@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
||||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
@@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
|
||||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
@@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
|
||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
// Copy destination pointers to GPU to be available when pointer indirection is in use
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
if (cuda_graph->dest_ptrs_d != nullptr) {
|
||||
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
|
||||
}
|
||||
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
|
||||
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
|
||||
}
|
||||
// copy destination pointers to GPU
|
||||
CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
|
||||
cuda_graph->graph_cpynode_index = 0; // reset index
|
||||
#else
|
||||
GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q8_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int num_blocks = ne / QK8_0;
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q8_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int num_blocks = ne / QK4_0;
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_0_f32_cuda(
|
||||
@@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int num_blocks = ne / QK4_1;
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_1_f32_cuda(
|
||||
@@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int num_blocks = ne / QK5_0;
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_0_f32_cuda(
|
||||
@@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int num_blocks = ne / QK5_1;
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_1_f32_cuda(
|
||||
@@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda(
|
||||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int num_blocks = ne / QK4_NL;
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
|
||||
@@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
char ** dest_ptrs_d = nullptr;
|
||||
int graph_cpynode_index = -1;
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
|
||||
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(disable_indirection_for_this_node);
|
||||
#endif
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
@@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
} else
|
||||
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
||||
{
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else {
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(disable_indirection_for_this_node);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
bool disable_indirection = true;
|
||||
ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
|
||||
}
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
// Prioritize CUDA graph compatibility over direct memory copy optimization.
|
||||
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, half>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, half>>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, float>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
ggml_cuda_cpy(ctx, src0, dst);
|
||||
}
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
|
||||
#define CUDA_CPY_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
|
||||
|
||||
@@ -793,8 +793,6 @@ void launch_fattn(
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
@@ -878,7 +876,7 @@ void launch_fattn(
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
// multiple sequences of possibly different lengths.
|
||||
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
|
||||
@@ -916,8 +914,7 @@ void launch_fattn(
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
@@ -946,7 +943,7 @@ void launch_fattn(
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
|
||||
@@ -1,756 +1,45 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
|
||||
// kq_stride == number of KQ rows to process per iteration
|
||||
// kq_nbatch == number of K columns to load in parallel for KQ calculation
|
||||
|
||||
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
if (GGML_CUDA_CC_IS_RDNA(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
if (fast_fp16_available(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
GGML_UNUSED(warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
#ifdef RDNA
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // RDNA
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
return 128;
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED_VARS(cc, ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
|
||||
#ifdef RDNA
|
||||
return 3;
|
||||
#else
|
||||
return ncols <= 16 ? 3 : 2;
|
||||
#endif // RDNA
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
|
||||
static __global__ void flash_attn_tile(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef GGML_USE_WMMA_FATTN
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // GGML_USE_WMMA_FATTN
|
||||
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = 32;
|
||||
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
|
||||
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
|
||||
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
|
||||
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
|
||||
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
|
||||
|
||||
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
constexpr int cpw = ncols/nwarps; // cols per warp
|
||||
|
||||
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
||||
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
||||
|
||||
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ half2 Q_tmp[ncols][D/2];
|
||||
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#else
|
||||
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
||||
|
||||
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ float Q_tmp[ncols][D];
|
||||
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
||||
|
||||
float KQ_max[cpw];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
// Load Q data, convert to FP16 if fast.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
const int j = j0 + threadIdx.y*cpw;
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp_f[i1] *= scale;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float KQ_max_new[cpw];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < cpw; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
||||
|
||||
// KQ = K @ Q matrix multiplication:
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
||||
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
half2 tmp_h2[cpy_ne_kqnb/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_kqnb/2];
|
||||
#pragma unroll
|
||||
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
||||
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
|
||||
half2 K_k[kq_stride/warp_size][cpy_ne];
|
||||
half2 Q_k[cpw][cpy_ne];
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
|
||||
float K_k[kq_stride/warp_size][cpy_ne];
|
||||
float Q_k[cpw][cpy_ne];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < cpy_ne; ++k) {
|
||||
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (k_KQ_0 + kq_nbatch < D) {
|
||||
__syncthreads(); // Sync not needed on last iteration.
|
||||
}
|
||||
}
|
||||
|
||||
// Apply logit softcap, mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#else
|
||||
float tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
|
||||
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
|
||||
KQ_max[j0+j1] = KQ_max_new[j0+j1];
|
||||
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/warp_size][j1] = val;
|
||||
}
|
||||
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
|
||||
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
|
||||
}
|
||||
}
|
||||
|
||||
// VKQ = V @ KQ matrix multiplication:
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
|
||||
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
|
||||
const int k_tile = k1 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
|
||||
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
half2 V_k[(D/2)/warp_size];
|
||||
half2 KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
half tmp[softmax_iter_j];
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
|
||||
&tmp, KQ[j][k0 + k1]);
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_k[j0+j1] = __half2half2(tmp[j1]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
|
||||
&KQ_k[j0], KQ[j][k0 + k1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
|
||||
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Attention sink: adjust running max and sum once per head
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
|
||||
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
|
||||
KQ_max[j0] = KQ_max_new_j;
|
||||
|
||||
const float val = expf(sink - KQ_max[j0]);
|
||||
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_sum[j0] += val;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
|
||||
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
float2 tmp[cpy_ne_D];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (D <= 128) {
|
||||
if (Q->ne[1] > 32) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
if (Q->ne[1] > 16) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
}
|
||||
|
||||
template <bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
switch (K->ne[0]) {
|
||||
case 40: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
|
||||
} break;
|
||||
case 64: {
|
||||
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
|
||||
} break;
|
||||
case 80: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
|
||||
} break;
|
||||
case 96: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
|
||||
} break;
|
||||
case 112: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
|
||||
} break;
|
||||
case 128: {
|
||||
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
|
||||
} break;
|
||||
case 256: {
|
||||
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("Unsupported head size");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
|
||||
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = false;
|
||||
constexpr bool need_f16_V = false;
|
||||
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
||||
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
@@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
|
||||
+52
-41
@@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
}
|
||||
}
|
||||
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
{ \
|
||||
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
|
||||
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
|
||||
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
@@ -198,6 +202,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
#endif// FLASH_ATTN_AVAILABLE
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
@@ -206,37 +211,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
// The effective batch size for the kernel can be increased by gqa_ratio.
|
||||
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
||||
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
// TODO: temporary until support is extended
|
||||
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
|
||||
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
switch (K->ne[0]) {
|
||||
case 40:
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 80:
|
||||
case 96:
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 576:
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
@@ -251,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
switch (K->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
@@ -270,47 +271,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
|
||||
|
||||
// If Turing tensor cores available, use them except for some cases with batch size 1:
|
||||
if (turing_mma_available(cc)) {
|
||||
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
if (Q->ne[1] <= 2) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
|
||||
if (!gqa_opt_applies && Q->ne[1] == 1) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
|
||||
return best;
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use kernels specialized for small batch sizes if possible:
|
||||
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
||||
// For large batch sizes, use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc)) {
|
||||
// Use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
|
||||
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
return BEST_FATTN_KERNEL_TILE;
|
||||
}
|
||||
|
||||
|
||||
@@ -273,6 +273,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
}
|
||||
|
||||
// Temporary performance fix:
|
||||
// Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
|
||||
// TODO: Check for future drivers the default scheduling strategy and
|
||||
// remove this call again when cudaDeviceScheduleSpin is default.
|
||||
if (prop.major == 12 && prop.minor == 1) {
|
||||
CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
|
||||
}
|
||||
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
|
||||
@@ -2633,11 +2642,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
|
||||
bool use_cuda_graph) {
|
||||
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
||||
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
|
||||
@@ -2688,33 +2696,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_CPY) {
|
||||
|
||||
// Store the pointers which are updated for each token, such that these can be sent
|
||||
// to the device and accessed using indirection from CUDA graph
|
||||
cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
|
||||
|
||||
// store a pointer to each copy op CUDA kernel to identify it later
|
||||
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
|
||||
if (!ptr) {
|
||||
use_cuda_graph = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
cuda_ctx->cuda_graph->use_cpy_indirection = true;
|
||||
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
|
||||
ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
|
||||
}
|
||||
|
||||
return use_cuda_graph;
|
||||
}
|
||||
|
||||
@@ -2733,7 +2719,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
|
||||
|
||||
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
if (node->data != graph_node_properties->node_address &&
|
||||
node->op != GGML_OP_CPY &&
|
||||
node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
@@ -2754,7 +2739,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i] &&
|
||||
node->src[i]->data != graph_node_properties->src_address[i] &&
|
||||
node->op != GGML_OP_CPY &&
|
||||
node->op != GGML_OP_VIEW
|
||||
) {
|
||||
return false;
|
||||
@@ -2901,7 +2885,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
|
||||
//if rms norm is the B operand, then we don't handle broadcast
|
||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
||||
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -3120,7 +3104,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
if (use_cuda_graph) {
|
||||
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
|
||||
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
|
||||
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
|
||||
|
||||
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
||||
if (use_cuda_graph && cuda_graph_update_required) {
|
||||
@@ -3147,10 +3131,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
||||
}
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
cuda_ctx->cuda_graph->use_cpy_indirection = false;
|
||||
}
|
||||
|
||||
#else
|
||||
bool use_cuda_graph = false;
|
||||
bool cuda_graph_update_required = false;
|
||||
@@ -3867,7 +3847,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "ggml.h"
|
||||
#include "mmf.cuh"
|
||||
#include "mmid.cuh"
|
||||
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
|
||||
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
|
||||
|
||||
mmf_ids_data ids_info{};
|
||||
mmf_ids_data * ids_info_ptr = nullptr;
|
||||
ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
|
||||
ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
|
||||
ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
|
||||
|
||||
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
||||
const int64_t ncols_dst = ids ? ne2 : ne1;
|
||||
const int64_t nchannels_dst = ids ? ne1 : ne2;
|
||||
@@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
nchannels_y = ids->ne[0];
|
||||
}
|
||||
|
||||
if (ids && ncols_dst > 16) {
|
||||
const int64_t n_expert_used = ids->ne[0];
|
||||
const int64_t n_experts = ne02;
|
||||
const int64_t n_tokens = ne12;
|
||||
const int64_t ne_get_rows = n_tokens * n_expert_used;
|
||||
|
||||
ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
|
||||
ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
|
||||
expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
|
||||
|
||||
const int si1 = static_cast<int>(ids_s1);
|
||||
const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
|
||||
|
||||
GGML_ASSERT(sis1 > 0);
|
||||
|
||||
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
|
||||
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
ids_info.ids_src_compact = ids_src_compact_dev.get();
|
||||
ids_info.ids_dst_compact = ids_dst_compact_dev.get();
|
||||
ids_info.expert_bounds_dev = expert_bounds_dev.get();
|
||||
ids_info.n_experts = static_cast<int>(n_experts);
|
||||
ids_info.sis1 = sis1;
|
||||
ids_info_ptr = &ids_info;
|
||||
}
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
@@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half2 * src0_d = (const half2 *) src0->data;
|
||||
@@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
||||
@@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
mul_mat_f_switch_cols_per_block(
|
||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||
@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
}
|
||||
|
||||
if (mul_mat_id) {
|
||||
if (type == GGML_TYPE_F32 && src1_ncols > 32) {
|
||||
if (src0_ne[1] <= 1024 && src1_ncols > 512) {
|
||||
return false;
|
||||
}
|
||||
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
|
||||
} else if(src0_ne[1] > 1024 && src1_ncols > 128) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
|
||||
+313
-31
@@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;
|
||||
|
||||
#define MMF_ROWS_PER_BLOCK 32
|
||||
|
||||
struct mmf_ids_data {
|
||||
const int32_t * ids_src_compact = nullptr;
|
||||
const int32_t * ids_dst_compact = nullptr;
|
||||
const int32_t * expert_bounds_dev = nullptr;
|
||||
int n_experts = 0;
|
||||
int sis1 = 0;
|
||||
};
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
|
||||
@@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
static __global__ void mul_mat_f_ids(
|
||||
const T * __restrict__ x, const float * __restrict__ y,
|
||||
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
|
||||
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
const int row0 = blockIdx.x * rows_per_block;
|
||||
|
||||
const int expert_idx = blockIdx.y;
|
||||
const int expert_start = expert_bounds[expert_idx];
|
||||
const int expert_end = expert_bounds[expert_idx + 1];
|
||||
const int ncols_expert = expert_end - expert_start;
|
||||
|
||||
const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
|
||||
const int tile_idx = blockIdx.z;
|
||||
if (tile_idx >= tiles_for_expert) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int col_base = tile_idx * cols_per_block;
|
||||
|
||||
GGML_UNUSED(channel_ratio);
|
||||
|
||||
const int channel_x = expert_idx;
|
||||
const int sample_dst = 0;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_y = sample_dst;
|
||||
|
||||
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
|
||||
y += int64_t(sample_y) *stride_sample_y;
|
||||
dst += int64_t(sample_dst)*stride_sample_dst;
|
||||
|
||||
const int32_t * ids_src_expert = ids_src_compact + expert_start;
|
||||
const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
char * compute_base = data_mmv;
|
||||
|
||||
//const float2 * y2 = (const float2 *) y;
|
||||
|
||||
tile_C C[ntA][ntB];
|
||||
|
||||
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
|
||||
|
||||
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
||||
tile_A A[ntA][warp_size / tile_A::J];
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_A::I; ++i) {
|
||||
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
||||
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
float vals_buf[2][tile_B::I];
|
||||
auto gather_tile = [&](int tile_idx_local, float *vals) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + tile_idx_local*tile_B::I;
|
||||
const int global_j = col_base + j;
|
||||
float val = 0.0f;
|
||||
if (j < cols_per_block && global_j < ncols_expert) {
|
||||
const int src_entry = ids_src_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
|
||||
const int token = (int) qrm.x;
|
||||
const int channel = (int) qrm.y;
|
||||
if (token < ncols_dst_total) {
|
||||
val = y[channel*stride_channel_y + token*stride_col_y + col];
|
||||
}
|
||||
}
|
||||
vals[j0] = val;
|
||||
}
|
||||
};
|
||||
|
||||
gather_tile(0, vals_buf[0]);
|
||||
|
||||
int curr_buf = 0;
|
||||
int next_buf = 1;
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
gather_tile(itB + 1, vals_buf[next_buf]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
curr_buf ^= 1;
|
||||
next_buf ^= 1;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
float2 vals_buf[2][tile_B::I];
|
||||
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + tile_idx_local*tile_B::I;
|
||||
const int global_j = col_base + j;
|
||||
float2 tmp = make_float2(0.0f, 0.0f);
|
||||
if (j < cols_per_block && global_j < ncols_expert) {
|
||||
const int src_entry = ids_src_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
|
||||
const int token = (int) qrm.x;
|
||||
const int channel = (int) qrm.y;
|
||||
if (token < ncols_dst_total) {
|
||||
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
|
||||
}
|
||||
}
|
||||
vals[j0] = tmp;
|
||||
}
|
||||
};
|
||||
|
||||
if (ntB > 0) {
|
||||
gather_tile(0, vals_buf[0]);
|
||||
}
|
||||
|
||||
int curr_buf = 0;
|
||||
int next_buf = 1;
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const float2 tmp = vals_buf[curr_buf][j0];
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
gather_tile(itB + 1, vals_buf[next_buf]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
curr_buf ^= 1;
|
||||
next_buf ^= 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
|
||||
const int j = itB*tile_C::J + tile_C::get_j(l);
|
||||
buf_iw[j*kiw + i] = C[itA][itB].x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
}
|
||||
|
||||
const int global_j = col_base + j;
|
||||
if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
|
||||
const int dst_entry = ids_dst_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
|
||||
const int token = (int) qrm.x;
|
||||
if (token < ncols_dst_total) {
|
||||
const int slot = (int) qrm.y;
|
||||
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
static inline void mul_mat_f_switch_ids(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
@@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids(
|
||||
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
||||
if (ids) {
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
|
||||
const mmf_ids_data * ids_data) {
|
||||
const bool has_ids_data = ids_data && ids_data->ids_src_compact;
|
||||
|
||||
// Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
|
||||
// we prefer the normal mul_mat_f path with has_ids=true.
|
||||
if (has_ids_data && ncols_dst > 16) {
|
||||
const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
|
||||
if (max_tiles == 0) {
|
||||
return;
|
||||
}
|
||||
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
|
||||
|
||||
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
|
||||
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
|
||||
|
||||
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
|
||||
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
sis1_fd, nch_fd);
|
||||
} else if (ids) {
|
||||
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
|
||||
dim3 block_nums_ids = block_nums;
|
||||
block_nums_ids.y *= col_tiles;
|
||||
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
@@ -258,7 +532,7 @@ void mul_mat_f_cuda(
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
|
||||
@@ -290,7 +564,7 @@ void mul_mat_f_cuda(
|
||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
|
||||
const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
|
||||
|
||||
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
|
||||
const dim3 block_dims(warp_size, nwarps_best, 1);
|
||||
@@ -300,49 +574,57 @@ void mul_mat_f_cuda(
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
|
||||
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
|
||||
|
||||
@@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
|
||||
cudaStream_t stream);
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data);
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
#include "common.cuh"
|
||||
#include "mmid.cuh"
|
||||
|
||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
||||
struct mm_ids_helper_store {
|
||||
uint32_t data;
|
||||
|
||||
__device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
||||
}
|
||||
|
||||
__device__ uint32_t it() const {
|
||||
return data & 0x003FFFFF;
|
||||
}
|
||||
|
||||
__device__ uint32_t iex_used() const {
|
||||
return data >> 22;
|
||||
}
|
||||
};
|
||||
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
|
||||
|
||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
||||
// ids_dst describes the same mapping but for the dst tensor.
|
||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mm_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mm_ids_helper[];
|
||||
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mm_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||
const int it = it0 + threadIdx.x / neu_padded;
|
||||
|
||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||
ids[it*si1 + iex] : INT_MAX;
|
||||
const int iex_used = expert_used == expert ? iex : -1;
|
||||
nex_prev += expert_used < expert;
|
||||
|
||||
// Whether the threads at this token position have used the expert:
|
||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||
|
||||
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||
int it_compact_add_lower = 0;
|
||||
#pragma unroll
|
||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
|
||||
it_compact_add_lower += tmp;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||
}
|
||||
}
|
||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||
|
||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||
const mm_ids_helper_store store_it = store[itc];
|
||||
const int it = store_it.it();
|
||||
const int iex_used = store_it.iex_used();
|
||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||
}
|
||||
|
||||
if (threadIdx.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[expert] = nex_prev;
|
||||
|
||||
if (expert < static_cast<int>(gridDim.x) - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||
}
|
||||
|
||||
template <int n_expert_used_template>
|
||||
static void launch_mm_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
|
||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
|
||||
|
||||
const dim3 num_blocks(n_experts, 1, 1);
|
||||
const dim3 block_size(warp_size, 1, 1);
|
||||
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
|
||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||
}
|
||||
|
||||
void ggml_cuda_launch_mm_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
switch (n_expert_used) {
|
||||
case 2:
|
||||
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
default:
|
||||
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
void ggml_cuda_launch_mm_ids_helper(
|
||||
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
|
||||
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
|
||||
+3
-166
@@ -1,141 +1,6 @@
|
||||
#include "mmq.cuh"
|
||||
#include "quantize.cuh"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
|
||||
struct mmq_ids_helper_store {
|
||||
uint32_t data;
|
||||
|
||||
__device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
|
||||
data = (it & 0x003FFFFF) | (iex_used << 22);
|
||||
}
|
||||
|
||||
__device__ uint32_t it() const {
|
||||
return data & 0x003FFFFF;
|
||||
}
|
||||
|
||||
__device__ uint32_t iex_used() const {
|
||||
return data >> 22;
|
||||
}
|
||||
};
|
||||
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
|
||||
|
||||
// Helper function for mul_mat_id, converts ids to a more convenient format.
|
||||
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
|
||||
// ids_dst describes the same mapping but for the dst tensor.
|
||||
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mmq_ids_helper[];
|
||||
mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
|
||||
const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
|
||||
for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
|
||||
const int it = it0 + threadIdx.x / neu_padded;
|
||||
|
||||
const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
|
||||
const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
|
||||
ids[it*si1 + iex] : INT_MAX;
|
||||
const int iex_used = expert_used == expert ? iex : -1;
|
||||
nex_prev += expert_used < expert;
|
||||
|
||||
// Whether the threads at this token position have used the expert:
|
||||
const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
|
||||
|
||||
// Do a scan over threads at lower token positions in warp to get the correct index for writing data:
|
||||
int it_compact_add_lower = 0;
|
||||
#pragma unroll
|
||||
for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
|
||||
const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
|
||||
if (threadIdx.x >= static_cast<unsigned int>(offset)) {
|
||||
it_compact_add_lower += tmp;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
// The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
|
||||
it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
|
||||
}
|
||||
}
|
||||
nex_prev = warp_reduce_sum<warp_size>(nex_prev);
|
||||
|
||||
for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
|
||||
const mmq_ids_helper_store store_it = store[itc];
|
||||
const int it = store_it.it();
|
||||
const int iex_used = store_it.iex_used();
|
||||
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
|
||||
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
|
||||
}
|
||||
|
||||
if (threadIdx.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[expert] = nex_prev;
|
||||
|
||||
if (expert < static_cast<int>(gridDim.x) - 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
expert_bounds[gridDim.x] = nex_prev + it_compact;
|
||||
}
|
||||
|
||||
template <int n_expert_used_template>
|
||||
static void launch_mmq_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
|
||||
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
|
||||
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
|
||||
|
||||
const dim3 num_blocks(n_experts, 1, 1);
|
||||
const dim3 block_size(warp_size, 1, 1);
|
||||
const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
|
||||
GGML_ASSERT(nbytes_shared <= smpbo);
|
||||
mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
|
||||
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
|
||||
}
|
||||
#include "mmid.cuh"
|
||||
|
||||
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||
switch (args.type_x) {
|
||||
@@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
|
||||
const int si1 = ids->nb[1] / ggml_element_size(ids);
|
||||
const int sis1 = nb12 / nb11;
|
||||
|
||||
switch (n_expert_used) {
|
||||
case 2:
|
||||
launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
default:
|
||||
launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
break;
|
||||
}
|
||||
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
|
||||
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
+44
-28
@@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
|
||||
static __global__ void mul_mat_vec_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const int row = blockIdx.x;
|
||||
const int channel_dst = blockIdx.y;
|
||||
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
||||
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
|
||||
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
|
||||
const int sample_y = sample_dst;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
@@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += tmpx.x*tmpy.x;
|
||||
sumf[j] += tmpx.y*tmpy.y;
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half>) {
|
||||
@@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += tmpx.x * tmpy.x;
|
||||
sumf[j] += tmpx.y * tmpy.y;
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
||||
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
||||
#if defined(GGML_USE_HIP)
|
||||
const int * x2 = (const int *) x;
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const int tmpx = x2[col2];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
||||
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
||||
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
|
||||
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
||||
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
||||
}
|
||||
}
|
||||
#else
|
||||
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const nv_bfloat162 tmpx = x2[col2];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
@@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
|
||||
GGML_ASSERT(stride_col_y % 2 == 0);
|
||||
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||||
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||||
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
||||
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
|
||||
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
@@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
|
||||
case 32: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(112, 112);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(128, 128);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(256, 256);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(40, 40);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(576, 512);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(64, 64);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(80, 80);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(96, 96);
|
||||
@@ -3,8 +3,17 @@
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
|
||||
|
||||
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
@@ -51,6 +60,11 @@ def get_short_name(long_quant_name):
|
||||
for filename in glob("*.cu"):
|
||||
os.remove(filename)
|
||||
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type_k in TYPES_KV:
|
||||
for type_v in TYPES_KV:
|
||||
with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
@@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]:
|
||||
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
|
||||
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
if head_size_kq == 40:
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
|
||||
@@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
|
||||
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
|
||||
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
|
||||
@@ -1482,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_SGD);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
@@ -134,6 +135,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
#include <Metal/Metal.h>
|
||||
|
||||
#include <stdatomic.h>
|
||||
|
||||
#ifndef TARGET_OS_VISION
|
||||
#define TARGET_OS_VISION 0
|
||||
#endif
|
||||
@@ -22,6 +24,9 @@
|
||||
// overload of MTLGPUFamilyMetal3 (not available in some environments)
|
||||
static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
|
||||
|
||||
// virtual address for GPU memory allocations
|
||||
static atomic_uintptr_t g_addr_device = 0x000000400ULL;
|
||||
|
||||
#if !GGML_METAL_EMBED_LIBRARY
|
||||
// Here to assist with NSBundle Path Hack
|
||||
@interface GGMLMetalClass : NSObject
|
||||
@@ -656,6 +661,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
@@ -692,7 +698,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
// for new head sizes, add checks here
|
||||
if (op->src[0]->ne[0] != 40 &&
|
||||
if (op->src[0]->ne[0] != 32 &&
|
||||
op->src[0]->ne[0] != 40 &&
|
||||
op->src[0]->ne[0] != 64 &&
|
||||
op->src[0]->ne[0] != 80 &&
|
||||
op->src[0]->ne[0] != 96 &&
|
||||
@@ -798,6 +805,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return false;
|
||||
};
|
||||
}
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return has_simdgroup_reduction;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -822,7 +832,7 @@ struct ggml_metal_buffer_wrapper {
|
||||
};
|
||||
|
||||
struct ggml_metal_buffer {
|
||||
void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
|
||||
void * all_data;
|
||||
size_t all_size;
|
||||
|
||||
// if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
|
||||
@@ -960,14 +970,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||
if (shared) {
|
||||
res->all_data = ggml_metal_host_malloc(size_aligned);
|
||||
res->is_shared = true;
|
||||
res->owned = true;
|
||||
} else {
|
||||
// dummy, non-NULL value - we'll populate this after creating the Metal buffer below
|
||||
res->all_data = (void *) 0x000000400ULL;
|
||||
// use virtual address from g_addr_device counter
|
||||
res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
|
||||
res->is_shared = false;
|
||||
}
|
||||
res->all_size = size_aligned;
|
||||
|
||||
res->owned = true;
|
||||
|
||||
res->device = ggml_metal_device_get_obj(dev);
|
||||
res->queue = ggml_metal_device_get_queue(dev);
|
||||
|
||||
@@ -978,15 +989,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
|
||||
res->buffers[0].metal = nil;
|
||||
|
||||
if (size_aligned > 0) {
|
||||
if (props_dev->use_shared_buffers &&shared) {
|
||||
if (props_dev->use_shared_buffers && shared) {
|
||||
res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
|
||||
length:size_aligned
|
||||
options:MTLResourceStorageModeShared
|
||||
deallocator:nil];
|
||||
} else {
|
||||
res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
|
||||
|
||||
res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1134,7 +1143,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
|
||||
|
||||
void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memset((char *)tensor->data + offset, value, size);
|
||||
memset((char *) tensor->data + offset, value, size);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1163,7 +1172,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
|
||||
|
||||
void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memcpy((char *)tensor->data + offset, data, size);
|
||||
memcpy((char *) tensor->data + offset, data, size);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1218,7 +1227,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
|
||||
|
||||
void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
if (buf->is_shared) {
|
||||
memcpy(data, (const char *)tensor->data + offset, size);
|
||||
memcpy(data, (const char *) tensor->data + offset, size);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -251,6 +251,7 @@ typedef struct {
|
||||
int32_t sect_1;
|
||||
int32_t sect_2;
|
||||
int32_t sect_3;
|
||||
bool src2;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
@@ -544,6 +545,10 @@ typedef struct{
|
||||
float limit;
|
||||
} ggml_metal_kargs_glu;
|
||||
|
||||
typedef struct {
|
||||
uint64_t np;
|
||||
} ggml_metal_kargs_sum;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -773,4 +778,12 @@ typedef struct {
|
||||
uint64_t nb01;
|
||||
} ggml_metal_kargs_argmax;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_adamw;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_sgd;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
||||
@@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_glu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_sum(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
@@ -410,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
||||
@@ -840,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
|
||||
|
||||
ggml_metal_kargs_sum args = {
|
||||
/*.np =*/ n,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -1546,9 +1582,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
!ggml_is_transposed(op->src[1]) &&
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
||||
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
|
||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
|
||||
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
@@ -2934,6 +2969,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
/* sect_1 =*/ sect_1,
|
||||
/* sect_2 =*/ sect_2,
|
||||
/* sect_3 =*/ sect_3,
|
||||
/* src2 =*/ op->src[2] != nullptr,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
|
||||
@@ -3402,3 +3438,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_adamw args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_sgd args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
@@ -78,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_op_sum_f32(
|
||||
constant ggml_metal_kargs_sum & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||
|
||||
if (tiitg != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
float acc = 0.0f;
|
||||
for (ulong i = 0; i < args.np; ++i) {
|
||||
acc += src0[i];
|
||||
}
|
||||
|
||||
dst[0] = acc;
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
@@ -3730,7 +3748,7 @@ kernel void kernel_rope_norm(
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
@@ -3783,7 +3801,7 @@ kernel void kernel_rope_neox(
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
@@ -3854,7 +3872,7 @@ kernel void kernel_rope_multi(
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
@@ -3921,7 +3939,7 @@ kernel void kernel_rope_vision(
|
||||
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
||||
// end of mrope
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
@@ -5195,8 +5213,30 @@ kernel void kernel_flash_attn_ext(
|
||||
half, half4, simdgroup_half8x8
|
||||
//float, float4, simdgroup_float8x8
|
||||
|
||||
#define FA_TYPES_F32 \
|
||||
half, half4, simdgroup_half8x8, \
|
||||
float, float4x4, simdgroup_float8x8, \
|
||||
float, float4x4, simdgroup_float8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
float, float2, simdgroup_float8x8, \
|
||||
float, float4, simdgroup_float8x8
|
||||
//half, half4, simdgroup_half8x8
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
@@ -5209,6 +5249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
@@ -5221,6 +5262,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
@@ -5232,6 +5274,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
@@ -5243,6 +5286,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
@@ -5254,6 +5298,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
@@ -5265,6 +5310,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
@@ -5800,77 +5846,103 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
float, float4, \
|
||||
float4
|
||||
|
||||
#define FA_TYPES_F32 \
|
||||
half4, \
|
||||
float4, \
|
||||
float4, \
|
||||
float, \
|
||||
float, float4, \
|
||||
float4
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
@@ -7487,7 +7559,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
||||
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
template<int NR0, typename args_t>
|
||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7500,13 +7572,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK4_NL;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
const int first_row = (r0 * NSG + sgitg) * NR0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
@@ -7517,6 +7588,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int nb = args.ne00/QK4_NL;
|
||||
const int ns01 = args.nb01/args.nb00;
|
||||
|
||||
const short ix = tiisg/2; // 0...15
|
||||
const short it = tiisg%2; // 0 or 1
|
||||
|
||||
@@ -7524,24 +7598,25 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[nr0]={0.f};
|
||||
float sumf[NR0]={0.f};
|
||||
|
||||
device const float * yb = y + ix * QK4_NL + it * 8;
|
||||
device const float * yb = y + ix*QK4_NL + it*8;
|
||||
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 16) {
|
||||
// [TAG_MUL_MV_WEIRD]
|
||||
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0];
|
||||
yl[1] = y4[4];
|
||||
yl[2] = y4[1];
|
||||
yl[3] = y4[5];
|
||||
|
||||
for (short row = 0; row < nr0; row++) {
|
||||
device const block_iq4_nl & xb = x[row*nb + ib];
|
||||
for (short row = 0; row < NR0; row++) {
|
||||
device const block_iq4_nl & xb = x[row*ns01 + ib];
|
||||
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
@@ -7572,7 +7647,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
||||
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
@@ -7594,7 +7669,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
template<int NR0, typename args_t>
|
||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7607,12 +7682,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK_K;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
const int first_row = (r0 * NSG + sgitg) * NR0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
@@ -7623,6 +7697,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int ns01 = args.nb01/args.nb00;
|
||||
|
||||
const short ix = tiisg/16; // 0 or 1
|
||||
const short it = tiisg%16; // 0...15
|
||||
const short ib = it/2;
|
||||
@@ -7632,7 +7709,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[nr0]={0.f};
|
||||
float sumf[NR0]={0.f};
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
|
||||
@@ -7641,15 +7718,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||
// [TAG_MUL_MV_WEIRD]
|
||||
for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0];
|
||||
yl[1] = y4[4];
|
||||
yl[2] = y4[1];
|
||||
yl[3] = y4[5];
|
||||
|
||||
for (short row = 0; row < nr0; ++row) {
|
||||
device const block_iq4_xs & xb = x[row*nb + ibl];
|
||||
for (short row = 0; row < NR0; ++row) {
|
||||
device const block_iq4_xs & xb = x[row*ns01 + ibl];
|
||||
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
@@ -7679,7 +7757,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
||||
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
@@ -7701,7 +7779,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
template<int NR0, typename args_t>
|
||||
void kernel_mul_mv_mxfp4_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7714,13 +7792,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK_MXFP4;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
const int first_row = (r0 * NSG + sgitg) * NR0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
@@ -7731,6 +7808,9 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int nb = args.ne00/QK_MXFP4;
|
||||
const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
|
||||
|
||||
const short ix = tiisg/2; // 0...15
|
||||
const short it = tiisg%2; // 0 or 1
|
||||
|
||||
@@ -7738,20 +7818,22 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[nr0]={0.f};
|
||||
float sumf[NR0]={0.f};
|
||||
|
||||
device const float * yb = y + ix * QK_MXFP4 + it * 8;
|
||||
device const float * yb = y + ix*QK_MXFP4 + it*8;
|
||||
|
||||
// note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
|
||||
// no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
|
||||
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
|
||||
device const float4 * y4 = (device const float4 *) yb;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 16) {
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0];
|
||||
yl[1] = y4[4];
|
||||
yl[2] = y4[1];
|
||||
yl[3] = y4[5];
|
||||
|
||||
#pragma unroll(nr0)
|
||||
for (short row = 0; row < nr0; row++) {
|
||||
device const block_mxfp4 & xb = x[row*nb + ib];
|
||||
FOR_UNROLL (short row = 0; row < NR0; row++) {
|
||||
device const block_mxfp4 & xb = x[row*ns01 + ib];
|
||||
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
|
||||
|
||||
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
|
||||
@@ -7769,7 +7851,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
||||
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
@@ -8744,3 +8826,51 @@ kernel void kernel_pool_2d_avg_f32(
|
||||
|
||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_adamw_f32(
|
||||
constant ggml_metal_kargs_opt_step_adamw & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device float * g_m,
|
||||
device float * g_v,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float alpha = pars[0];
|
||||
const float beta1 = pars[1];
|
||||
const float beta2 = pars[2];
|
||||
const float eps = pars[3];
|
||||
const float wd = pars[4];
|
||||
const float beta1h = pars[5];
|
||||
const float beta2h = pars[6];
|
||||
|
||||
const float gi = g[gid];
|
||||
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
|
||||
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
|
||||
|
||||
g_m[gid] = gmi;
|
||||
g_v[gid] = gvi;
|
||||
|
||||
const float mh = gmi * beta1h;
|
||||
const float vh = sqrt(gvi * beta2h) + eps;
|
||||
|
||||
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_sgd_f32(
|
||||
constant ggml_metal_kargs_opt_step_sgd & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
|
||||
|
||||
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
|
||||
@@ -2348,8 +2348,13 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
|
||||
|
||||
if (opencl_c_version.major >= 3) {
|
||||
// Assume it is not available for 3.0, since it is optional in 3.0.
|
||||
// If compiling against 3.0, then we can query.
|
||||
backend_ctx->non_uniform_workgroups = false;
|
||||
#if CL_TARGET_OPENCL_VERSION >= 300
|
||||
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),
|
||||
&backend_ctx->non_uniform_workgroups, 0));
|
||||
#endif
|
||||
} else {
|
||||
GGML_ASSERT(opencl_c_version.major == 2);
|
||||
// Non-uniform workgroup sizes is mandatory feature in v2.x.
|
||||
@@ -2681,7 +2686,7 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
||||
|
||||
// if rms_norm is the B operand, then we don't handle broadcast
|
||||
if (rms_norm == mul->src[1] &&
|
||||
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
||||
!ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "concat.hpp"
|
||||
#include "conv.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "count-equal.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "dequantize.hpp"
|
||||
#include "dmmv.hpp"
|
||||
@@ -28,6 +29,7 @@
|
||||
#include "mmvq.hpp"
|
||||
#include "norm.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "pad.hpp"
|
||||
#include "quantize.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "rope.hpp"
|
||||
|
||||
@@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
@@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_sub(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_count_equal(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_mul(ctx, dst);
|
||||
|
||||
@@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_count_equal(const float a, const float b) {
|
||||
return (a == b) ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
static __dpct_inline__ float op_mul(const float a, const float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
@@ -195,7 +195,8 @@ struct optimize_feature {
|
||||
|
||||
struct sycl_device_info {
|
||||
int cc; // compute capability
|
||||
// int nsm; // number of streaming multiprocessors
|
||||
int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
|
||||
// number of compute units on a SYCL device.
|
||||
// size_t smpb; // max. shared memory per block
|
||||
size_t smpbo; // max. shared memory per block (with opt-in)
|
||||
bool vmm; // virtual memory support
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
#include "count-equal.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
template <typename T>
|
||||
static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
|
||||
int64_t *__restrict__ dst, const int64_t dk,
|
||||
const int64_t k) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
|
||||
const int64_t i1 = sycl::min(i0 + dk, k);
|
||||
|
||||
int nequal = 0;
|
||||
|
||||
for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
|
||||
const T xi = x[i];
|
||||
const T yi = y[i];
|
||||
nequal += xi == yi;
|
||||
}
|
||||
|
||||
nequal = warp_reduce_sum(nequal);
|
||||
|
||||
if (item_ct1.get_local_id(2) != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
|
||||
(int *)dst, nequal);
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_I64);
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
int64_t * dst_d = (int64_t *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
const int id = get_current_device_id();
|
||||
const int nsm = ggml_sycl_info().devices[id].nsm;
|
||||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
|
||||
const int64_t dne =
|
||||
GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
|
||||
|
||||
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dpct::dim3 block_nums(
|
||||
std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
|
||||
SYCL_COUNT_EQUAL_CHUNK_SIZE),
|
||||
1, 1);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_I32: {
|
||||
const int *src0_d = (const int *)src0->data;
|
||||
const int *src1_d = (const int *)src1->data;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
count_equal(src0_d, src1_d, dst_d, dne, ne);
|
||||
GGML_UNUSED(item_ct1);
|
||||
});
|
||||
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
#ifndef GGML_SYCL_COUNT_EQUAL_HPP
|
||||
#define GGML_SYCL_COUNT_EQUAL_HPP
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif //GGML_SYCL_COUNT_EQUAL_HPP
|
||||
@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
||||
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
|
||||
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
||||
item_ct1.get_group(0) * ne00 * ne01;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
dst[offset_dst] = static_cast<T>(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
||||
const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void pad_sycl(const T *x, T *dst, const int ne00,
|
||||
const int ne01, const int ne02, const int ne0,
|
||||
const int ne1, const int ne2, queue_ptr stream) {
|
||||
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
}
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ggml_sycl_detail
|
||||
|
||||
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
|
||||
queue_ptr stream) {
|
||||
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
float min_val;
|
||||
float max_val;
|
||||
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_upscale(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
|
||||
@@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -85,9 +85,11 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
||||
|
||||
info.devices[i].cc =
|
||||
100 * prop.get_major_version() + 10 * prop.get_minor_version();
|
||||
info.devices[i].nsm = prop.get_max_compute_units();
|
||||
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
|
||||
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
||||
info.devices[i].smpbo = prop.get_local_mem_size();
|
||||
|
||||
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
||||
}
|
||||
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
@@ -1512,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
|
||||
template <ggml_sort_order order>
|
||||
__dpct_inline__ static void
|
||||
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
|
||||
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
|
||||
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
|
||||
uint8_t *dpct_local) {
|
||||
// bitonic sort
|
||||
int col = item_ct1.get_local_id(2);
|
||||
int col_index = item_ct1.get_local_id(2);
|
||||
int row = item_ct1.get_group(1);
|
||||
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const float * x_row = x + row * ncols;
|
||||
auto dst_row = (int *)dpct_local;
|
||||
|
||||
// initialize indices
|
||||
dst_row[col] = col;
|
||||
for (int i=0;i<tasks_per_thread;i++){
|
||||
int col = col_index*tasks_per_thread+i;
|
||||
dst_row[col] = col;
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols &&
|
||||
(order == GGML_SORT_ORDER_ASC
|
||||
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
|
||||
: x_row[dst_row[col]] <
|
||||
x_row[dst_row[ixj]]))) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols &&
|
||||
(order == GGML_SORT_ORDER_ASC
|
||||
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
|
||||
: x_row[dst_row[col]] >
|
||||
x_row[dst_row[ixj]]))) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
/*
|
||||
DPCT1118:1: SYCL group functions and algorithms must be encountered
|
||||
in converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
@@ -1738,11 +1750,20 @@ static int next_power_of_2(int x) {
|
||||
|
||||
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
const int nrows, ggml_sort_order order,
|
||||
queue_ptr stream) {
|
||||
queue_ptr stream, int device) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, ncols_pad);
|
||||
int nth = 1;
|
||||
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
while (nth < ncols_pad && nth < max_block_size)
|
||||
nth *= 2;
|
||||
if (nth > max_block_size)
|
||||
nth = max_block_size;
|
||||
|
||||
const int tasks_per_thread = ncols_pad / nth;
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, nth);
|
||||
const sycl::range<3> block_nums(1, nrows, 1);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
|
||||
@@ -1755,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
});
|
||||
});
|
||||
@@ -1769,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
});
|
||||
});
|
||||
@@ -2142,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
|
||||
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
|
||||
main_stream, ctx.device);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
@@ -4413,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_PAD:
|
||||
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
|
||||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
//#include "common.hpp"
|
||||
#include "pad.hpp"
|
||||
|
||||
static void pad_f32(const float * src, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
int i0 = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
int i1 = item_ct1.get_group(1);
|
||||
int i2 = item_ct1.get_group(0) % ne2;
|
||||
int i3 = item_ct1.get_group(0) / ne2;
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
|
||||
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
||||
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t i00 = i0 - lp0;
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
|
||||
i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
|
||||
const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3,
|
||||
const int rp3, const int ne0, const int ne1,
|
||||
const int ne2, const int ne3,
|
||||
dpct::queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
||||
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
|
||||
ne2, ne3);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
|
||||
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
|
||||
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
|
||||
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
|
||||
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
|
||||
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
|
||||
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
|
||||
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
|
||||
|
||||
pad_f32_sycl(src0_d, dst_d,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_PAD_HPP
|
||||
#define GGML_SYCL_PAD_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_PAD_BLOCK_SIZE 256
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_PAD_HPP
|
||||
@@ -1,9 +1,18 @@
|
||||
cmake_minimum_required(VERSION 3.19)
|
||||
cmake_policy(SET CMP0114 NEW)
|
||||
cmake_policy(SET CMP0116 NEW)
|
||||
if (POLICY CMP0147)
|
||||
# Parallel build custom build steps
|
||||
cmake_policy(SET CMP0147 NEW)
|
||||
endif()
|
||||
|
||||
find_package(Vulkan COMPONENTS glslc REQUIRED)
|
||||
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
|
||||
# Parallel build object files
|
||||
add_definitions(/MP)
|
||||
endif()
|
||||
|
||||
function(detect_host_compiler)
|
||||
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
|
||||
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
@@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
} \
|
||||
}
|
||||
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
||||
@@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
#endif
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
||||
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
||||
@@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
|
||||
|
||||
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
|
||||
if (k->type == GGML_TYPE_F32) {
|
||||
k_stride /= 4;
|
||||
}
|
||||
if (v->type == GGML_TYPE_F32) {
|
||||
v_stride /= 4;
|
||||
}
|
||||
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
@@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
}
|
||||
switch (op->src[1]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
// supported in scalar and coopmat2 paths
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
|
||||
vec4 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const vec4 v = bl.block;
|
||||
const uint idx = coordInBlock[1];
|
||||
const f16vec4 vf16 = f16vec4(v);
|
||||
return vf16[idx];
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
|
||||
block_q4_0_packed16 block;
|
||||
};
|
||||
@@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define dequantFuncA dequantFuncMXFP4
|
||||
#elif defined(DATA_A_F32)
|
||||
#define dequantFuncA dequantFuncF32
|
||||
#endif
|
||||
|
||||
@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
|
||||
|
||||
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
|
||||
#elif defined(A_TYPE_PACKED16)
|
||||
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
#undef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 4
|
||||
#define BLOCK_BYTE_SIZE 16
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
// iqs is currently always zero in the flash attention shaders
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
return k_packed.k_data_packed[a_offset + ib];
|
||||
} else {
|
||||
return v_packed.v_data_packed[a_offset + ib];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
|
||||
@@ -313,12 +313,12 @@ void main() {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b[TN];
|
||||
FLOAT_TYPE_VEC2 cache_b;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -360,20 +360,22 @@ void main() {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
|
||||
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -388,8 +390,9 @@ void main() {
|
||||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
|
||||
sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -463,14 +466,21 @@ void main() {
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i - ic * BN];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
|
||||
#ifdef MUL_MAT_ID
|
||||
if (dr_warp + cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + 2 * cr < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
if (dr_warp + 2 * cr + 1 < p.M) {
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
|
||||
}
|
||||
#else
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
|
||||
}
|
||||
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
}
|
||||
|
||||
@@ -611,9 +611,6 @@ void process_shaders() {
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "f32") {
|
||||
continue;
|
||||
}
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
@@ -630,7 +627,7 @@ void process_shaders() {
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
@@ -639,7 +636,7 @@ void process_shaders() {
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0") {
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <map>
|
||||
|
||||
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize
|
||||
{ LLM_ARCH_LLAMA, "llama" },
|
||||
{ LLM_ARCH_LLAMA4, "llama4" },
|
||||
{ LLM_ARCH_DECI, "deci" },
|
||||
@@ -275,6 +276,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
};
|
||||
|
||||
static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
|
||||
{
|
||||
LLM_ARCH_CLIP,
|
||||
{},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA,
|
||||
{
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
//
|
||||
|
||||
enum llm_arch {
|
||||
LLM_ARCH_CLIP,
|
||||
LLM_ARCH_LLAMA,
|
||||
LLM_ARCH_LLAMA4,
|
||||
LLM_ARCH_DECI,
|
||||
|
||||
+74
-43
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
|
||||
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
|
||||
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
|
||||
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
|
||||
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
|
||||
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
|
||||
const char * swa_type_str = "unknown";
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
|
||||
case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
|
||||
case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
|
||||
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
|
||||
};
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
|
||||
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
|
||||
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
const int64_t n_kv = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
||||
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
const llama_pos p1 = ubatch->pos[i1];
|
||||
|
||||
float * data = (float *) kq_mask->data;
|
||||
const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
|
||||
|
||||
// [TAG_NO_CACHE_ISWA]
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
|
||||
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
||||
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
||||
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
||||
const llama_pos p0 = ubatch->pos[i0];
|
||||
|
||||
// mask different sequences
|
||||
if (s0 != s1) {
|
||||
continue; // skip different sequences
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
|
||||
continue; // skip future tokens for causal attention
|
||||
// mask future tokens
|
||||
if (cparams.causal_attn && p0 > p1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
|
||||
//if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
|
||||
// continue; // skip masked tokens for SWA
|
||||
//}
|
||||
|
||||
// TODO: reimplement this like in llama_kv_cache_unified
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
// apply SWA if any
|
||||
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
||||
}
|
||||
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
GGML_ASSERT(self_kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
|
||||
|
||||
fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
|
||||
}
|
||||
}
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(self_kq_mask_swa);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask_swa->data;
|
||||
|
||||
std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
|
||||
|
||||
fill_mask(data, hparams.n_swa, hparams.swa_type);
|
||||
|
||||
if (debug) {
|
||||
print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
|
||||
const auto n_kv = k->ne[1];
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
// TODO: replace hardcoded padding with ggml-provided padding
|
||||
if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
|
||||
if (cparams.flash_attn && kq_b == nullptr) {
|
||||
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
|
||||
|
||||
if (v_trans) {
|
||||
@@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
||||
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
||||
|
||||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
||||
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->kq_mask);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
} else {
|
||||
inp->self_kq_mask_swa = nullptr;
|
||||
inp->self_kq_mask_swa_cnv = nullptr;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
|
||||
}
|
||||
@@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
||||
|
||||
+7
-3
@@ -257,10 +257,14 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
|
||||
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
|
||||
// n_tokens == n_batch
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
@@ -140,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const {
|
||||
}
|
||||
|
||||
bool llama_hparams::is_recurrent(uint32_t il) const {
|
||||
return recurrent_layer_arr[il];
|
||||
if (il < n_layer) {
|
||||
return recurrent_layer_arr[il];
|
||||
}
|
||||
|
||||
GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer);
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_pos_per_embd() const {
|
||||
|
||||
+12
-11
@@ -478,7 +478,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_GENERAL_NAME, name, false);
|
||||
|
||||
// everything past this point is not vocab-related
|
||||
if (hparams.vocab_only) {
|
||||
// for CLIP models, we only need to load tensors, no hparams
|
||||
if (hparams.vocab_only || ml.get_arch() == LLM_ARCH_CLIP) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -11358,8 +11359,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_gemma_embedding_iswa : public llm_graph_context {
|
||||
llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
struct llm_build_gemma_embedding : public llm_graph_context {
|
||||
llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k;
|
||||
|
||||
ggml_tensor * cur;
|
||||
@@ -11376,8 +11377,7 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
auto * inp_attn = build_attn_inp_no_cache();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
@@ -16313,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
||||
}
|
||||
|
||||
ggml_tensor * build_layer_ffn(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inpSA,
|
||||
const llama_model & model,
|
||||
const int il) {
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inpSA,
|
||||
const llama_model & model,
|
||||
const int il) {
|
||||
|
||||
// For Granite architectures - scale residual
|
||||
if (hparams.f_residual_scale) {
|
||||
@@ -19378,7 +19378,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
case LLM_ARCH_NOMIC_BERT_MOE:
|
||||
case LLM_ARCH_NEO_BERT:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
//case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
@@ -19671,7 +19671,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA_EMBEDDING:
|
||||
{
|
||||
llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
|
||||
llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
@@ -20014,6 +20014,7 @@ int32_t llama_n_head(const llama_model * model) {
|
||||
llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
switch (model->arch) {
|
||||
// these models do not use RoPE
|
||||
case LLM_ARCH_CLIP:
|
||||
case LLM_ARCH_GPT2:
|
||||
case LLM_ARCH_GPTJ:
|
||||
case LLM_ARCH_MPT:
|
||||
|
||||
+7
-1
@@ -701,6 +701,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
});
|
||||
}
|
||||
|
||||
bool is_clip_model = false;
|
||||
for (const auto * it : tensors) {
|
||||
const struct ggml_tensor * tensor = it->tensor;
|
||||
|
||||
@@ -714,12 +715,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
||||
qs.has_output = true;
|
||||
}
|
||||
|
||||
is_clip_model |= name.rfind("mm.", 0) == 0; // check the "mm." prefix
|
||||
}
|
||||
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
||||
|
||||
// sanity checks for models that have attention layers
|
||||
if (qs.n_attention_wv != 0)
|
||||
if (qs.n_attention_wv != 0 && !is_clip_model)
|
||||
{
|
||||
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
||||
// attention layers have a non-zero number of kv heads
|
||||
@@ -881,6 +884,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// do not quantize relative position bias (T5)
|
||||
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
|
||||
|
||||
// do not quantize specific multimodal tensors
|
||||
quantize &= name.find(".position_embd.") == std::string::npos;
|
||||
|
||||
ggml_type new_type;
|
||||
void * new_data;
|
||||
size_t new_size;
|
||||
|
||||
@@ -2541,8 +2541,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
if (n_non_eog == 0) {
|
||||
cur_p->size = 1;
|
||||
cur_p->data[0].id = ctx->vocab->token_eot();
|
||||
if (cur_p->data[0].id == LLAMA_TOKEN_NULL) {
|
||||
cur_p->data[0].id = ctx->vocab->token_eos();
|
||||
}
|
||||
cur_p->data[0].logit = 1.0f;
|
||||
|
||||
GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -2171,6 +2171,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<|end|>"
|
||||
|| t.first == "<end_of_turn>"
|
||||
|| t.first == "<|endoftext|>"
|
||||
|| t.first == "<|end_of_text|>" // granite
|
||||
|| t.first == "<EOT>"
|
||||
|| t.first == "_<EOT>"
|
||||
|| t.first == "<|end▁of▁sentence|>" // DeepSeek
|
||||
|
||||
@@ -124,6 +124,9 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
|
||||
} catch(const std::exception & e) {
|
||||
throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
|
||||
}
|
||||
if (model.arch == LLM_ARCH_CLIP) {
|
||||
throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead");
|
||||
}
|
||||
try {
|
||||
model.load_vocab(ml);
|
||||
} catch(const std::exception & e) {
|
||||
@@ -312,6 +315,7 @@ struct llama_model * llama_model_load_from_splits(
|
||||
LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
splits.reserve(n_paths);
|
||||
for (size_t i = 0; i < n_paths; ++i) {
|
||||
splits.push_back(paths[i]);
|
||||
}
|
||||
|
||||
@@ -6779,7 +6779,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(
|
||||
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
|
||||
// run fewer test cases permuted
|
||||
@@ -6911,7 +6911,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
}
|
||||
|
||||
// qwen3-30b-a3b
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 512}) {
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
||||
@@ -6919,6 +6919,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
}
|
||||
}
|
||||
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// gpt-oss-20b
|
||||
for (int bs : {1, 4, 8, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
|
||||
|
||||
@@ -524,6 +524,64 @@ static void test_json_with_dumped_args() {
|
||||
R"({"foo": "bar", "args": {"arg1": [)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":["})"
|
||||
);
|
||||
|
||||
// Unicode tests
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud8)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud80)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\ud)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_positions() {
|
||||
|
||||
@@ -58,7 +58,7 @@ static void test_json_healing() {
|
||||
for (const auto & input : inputs) {
|
||||
common_json out;
|
||||
assert_equals(true, common_json_parse(input, "$foo", out));
|
||||
assert_equals<std::string>(expected, out.json.dump());
|
||||
assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
|
||||
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
|
||||
}
|
||||
};
|
||||
@@ -228,6 +228,56 @@ static void test_json_healing() {
|
||||
R"({"key":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
// Test unicode escape sequences
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(0000$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u00)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud300)",
|
||||
},
|
||||
R"({"a":"\ud300$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(\udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\u)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(dc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\udc00)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
Binary file not shown.
+32
-22
@@ -1585,23 +1585,31 @@ struct server_prompt_cache {
|
||||
}
|
||||
}
|
||||
|
||||
// average size per token
|
||||
const float size_per_token = std::max<float>(1.0f, float(size()) / (std::max<size_t>(1, n_tokens())));
|
||||
|
||||
// dynamically increase the token limit if it can fit in the memory limit
|
||||
const size_t limit_tokens_cur = limit_size > 0 ? std::max<size_t>(limit_tokens, limit_size/size_per_token) : limit_tokens;
|
||||
|
||||
if (limit_tokens > 0) {
|
||||
while (states.size() > 1 && n_tokens() > limit_tokens) {
|
||||
while (states.size() > 1 && n_tokens() > limit_tokens_cur) {
|
||||
if (states.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));
|
||||
SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n",
|
||||
limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0));
|
||||
|
||||
states.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n",
|
||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens);
|
||||
SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n",
|
||||
states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur);
|
||||
|
||||
for (const auto & state : states) {
|
||||
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n",
|
||||
(const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -3727,7 +3735,7 @@ struct server_context {
|
||||
}
|
||||
} else {
|
||||
if (slot.n_prompt_tokens() >= slot.n_ctx) {
|
||||
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
@@ -3804,7 +3812,7 @@ struct server_context {
|
||||
if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
|
||||
}
|
||||
|
||||
@@ -3852,7 +3860,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
const auto it = std::find_if(
|
||||
@@ -4020,7 +4028,7 @@ struct server_context {
|
||||
}
|
||||
}
|
||||
|
||||
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
|
||||
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
|
||||
|
||||
@@ -4226,7 +4234,7 @@ struct server_context {
|
||||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
@@ -4368,7 +4376,7 @@ struct server_context {
|
||||
|
||||
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
|
||||
// skip GH copilot requests when using default port
|
||||
if (req.path == "/v1/health" || req.path == "/v1/completions") {
|
||||
if (req.path == "/v1/health") {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4955,9 +4963,17 @@ int main(int argc, char ** argv) {
|
||||
// Everything else, including multimodal completions.
|
||||
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||
}
|
||||
|
||||
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
|
||||
tasks.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
auto n_prompt_tokens = inputs[i].size();
|
||||
if (n_prompt_tokens >= n_ctx_slot) {
|
||||
json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||
error_data["n_prompt_tokens"] = n_prompt_tokens;
|
||||
error_data["n_ctx"] = n_ctx_slot;
|
||||
res_error(res, error_data);
|
||||
return;
|
||||
}
|
||||
server_task task = server_task(type);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
@@ -5393,15 +5409,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
// TODO: implement
|
||||
//int top_n = 1;
|
||||
//if (body.count("top_n") != 1) {
|
||||
// top_n = body.at("top_n");
|
||||
//} else {
|
||||
// res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
// return;
|
||||
//}
|
||||
|
||||
// if true, use TEI API format, otherwise use Jina API format
|
||||
// Jina: https://jina.ai/reranker/
|
||||
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
|
||||
@@ -5426,6 +5433,8 @@ int main(int argc, char ** argv) {
|
||||
return;
|
||||
}
|
||||
|
||||
int top_n = json_value(body, "top_n", (int)documents.size());
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
@@ -5466,7 +5475,8 @@ int main(int argc, char ** argv) {
|
||||
body,
|
||||
responses,
|
||||
is_tei_format,
|
||||
documents);
|
||||
documents,
|
||||
top_n);
|
||||
|
||||
res_ok(res, root);
|
||||
};
|
||||
|
||||
@@ -408,6 +408,28 @@ def test_context_size_exceeded():
|
||||
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
|
||||
|
||||
|
||||
def test_context_size_exceeded_stream():
|
||||
global server
|
||||
server.start()
|
||||
try:
|
||||
for _ in server.make_stream_request("POST", "/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
] * 100, # make the prompt too long
|
||||
"stream": True}):
|
||||
pass
|
||||
assert False, "Should have failed"
|
||||
except ServerError as e:
|
||||
assert e.code == 400
|
||||
assert "error" in e.body
|
||||
assert e.body["error"]["type"] == "exceed_context_size_error"
|
||||
assert e.body["error"]["n_prompt_tokens"] > 0
|
||||
assert server.n_ctx is not None
|
||||
assert server.n_slots is not None
|
||||
assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"n_batch,batch_count,reuse_cache",
|
||||
[
|
||||
|
||||
@@ -102,3 +102,45 @@ def test_rerank_usage(query, doc1, doc2, n_tokens):
|
||||
assert res.status_code == 200
|
||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||
assert res.body['usage']['prompt_tokens'] == n_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_n,expected_len", [
|
||||
(None, len(TEST_DOCUMENTS)), # no top_n parameter
|
||||
(2, 2),
|
||||
(4, 4),
|
||||
(99, len(TEST_DOCUMENTS)), # higher than available docs
|
||||
])
|
||||
def test_rerank_top_n(top_n, expected_len):
|
||||
global server
|
||||
server.start()
|
||||
data = {
|
||||
"query": "Machine learning is",
|
||||
"documents": TEST_DOCUMENTS,
|
||||
}
|
||||
if top_n is not None:
|
||||
data["top_n"] = top_n
|
||||
|
||||
res = server.make_request("POST", "/rerank", data=data)
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["results"]) == expected_len
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_n,expected_len", [
|
||||
(None, len(TEST_DOCUMENTS)), # no top_n parameter
|
||||
(2, 2),
|
||||
(4, 4),
|
||||
(99, len(TEST_DOCUMENTS)), # higher than available docs
|
||||
])
|
||||
def test_rerank_tei_top_n(top_n, expected_len):
|
||||
global server
|
||||
server.start()
|
||||
data = {
|
||||
"query": "Machine learning is",
|
||||
"texts": TEST_DOCUMENTS,
|
||||
}
|
||||
if top_n is not None:
|
||||
data["top_n"] = top_n
|
||||
|
||||
res = server.make_request("POST", "/rerank", data=data)
|
||||
assert res.status_code == 200
|
||||
assert len(res.body) == expected_len
|
||||
|
||||
@@ -35,6 +35,12 @@ class ServerResponse:
|
||||
body: dict | Any
|
||||
|
||||
|
||||
class ServerError(Exception):
|
||||
def __init__(self, code, body):
|
||||
self.code = code
|
||||
self.body = body
|
||||
|
||||
|
||||
class ServerProcess:
|
||||
# default options
|
||||
debug: bool = False
|
||||
@@ -297,6 +303,8 @@ class ServerProcess:
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
else:
|
||||
raise ValueError(f"Unimplemented method: {method}")
|
||||
if response.status_code != 200:
|
||||
raise ServerError(response.status_code, response.json())
|
||||
for line_bytes in response.iter_lines():
|
||||
line = line_bytes.decode("utf-8")
|
||||
if '[DONE]' in line:
|
||||
|
||||
+38
-40
@@ -849,47 +849,44 @@ static json format_response_rerank(
|
||||
const json & request,
|
||||
const json & ranks,
|
||||
bool is_tei_format,
|
||||
std::vector<std::string> & texts) {
|
||||
json res;
|
||||
if (is_tei_format) {
|
||||
// TEI response format
|
||||
res = json::array();
|
||||
bool return_text = json_value(request, "return_text", false);
|
||||
for (const auto & rank : ranks) {
|
||||
int index = json_value(rank, "index", 0);
|
||||
json elem = json{
|
||||
{"index", index},
|
||||
{"score", json_value(rank, "score", 0.0)},
|
||||
};
|
||||
if (return_text) {
|
||||
elem["text"] = std::move(texts[index]);
|
||||
}
|
||||
res.push_back(elem);
|
||||
}
|
||||
} else {
|
||||
// Jina response format
|
||||
json results = json::array();
|
||||
int32_t n_tokens = 0;
|
||||
for (const auto & rank : ranks) {
|
||||
results.push_back(json{
|
||||
{"index", json_value(rank, "index", 0)},
|
||||
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||
});
|
||||
|
||||
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
||||
}
|
||||
|
||||
res = json{
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", "list"},
|
||||
{"usage", json{
|
||||
{"prompt_tokens", n_tokens},
|
||||
{"total_tokens", n_tokens}
|
||||
}},
|
||||
{"results", results}
|
||||
std::vector<std::string> & texts,
|
||||
int top_n) {
|
||||
int32_t n_tokens = 0;
|
||||
bool return_text = is_tei_format && json_value(request, "return_text", false);
|
||||
std::vector<json> elements; // Temporary vector to hold unsorted elements
|
||||
std::string score_label = is_tei_format ? "score" : "relevance_score";
|
||||
for (const auto & rank : ranks) {
|
||||
int index = json_value(rank, "index", 0);
|
||||
json elem = json{
|
||||
{"index", index},
|
||||
{score_label, json_value(rank, "score", 0.0)},
|
||||
};
|
||||
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
||||
if (return_text) {
|
||||
elem["text"] = std::move(texts[index]);
|
||||
}
|
||||
elements.push_back(elem);
|
||||
}
|
||||
|
||||
std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
|
||||
return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
|
||||
});
|
||||
|
||||
elements.resize(std::min(top_n, (int)elements.size()));
|
||||
json results = elements;
|
||||
|
||||
if (is_tei_format) return results;
|
||||
|
||||
json res = json{
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", "list"},
|
||||
{"usage", json{
|
||||
{"prompt_tokens", n_tokens},
|
||||
{"total_tokens", n_tokens}
|
||||
}},
|
||||
{"results", results}
|
||||
};
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1240,9 +1237,10 @@ public:
|
||||
// allowed to resize ^ ^
|
||||
// disallowed to resize ^ ^ ^
|
||||
if (n > 0) {
|
||||
llama_token last_token = tokens[n - 1];
|
||||
// make sure we never remove tokens in the middle of an image
|
||||
if (last_token == LLAMA_TOKEN_NULL) {
|
||||
// note that the case where we keep a full image at the end is allowed:
|
||||
// tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL
|
||||
if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) {
|
||||
find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+69
@@ -50,6 +50,7 @@
|
||||
"eslint-plugin-svelte": "^3.0.0",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^16.0.0",
|
||||
"mdast": "^3.0.0",
|
||||
"mdsvex": "^0.12.3",
|
||||
"playwright": "^1.53.0",
|
||||
"prettier": "^3.4.2",
|
||||
@@ -66,6 +67,7 @@
|
||||
"tw-animate-css": "^1.3.5",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.20.0",
|
||||
"unified": "^11.0.5",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "^7.0.4",
|
||||
"vite-plugin-devtools-json": "^0.2.0",
|
||||
@@ -2128,6 +2130,66 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": {
|
||||
"version": "1.4.3",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/wasi-threads": "1.0.2",
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": {
|
||||
"version": "1.4.3",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
|
||||
"version": "1.0.2",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": {
|
||||
"version": "0.2.11",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/core": "^1.4.3",
|
||||
"@emnapi/runtime": "^1.4.3",
|
||||
"@tybys/wasm-util": "^0.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": {
|
||||
"version": "0.9.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": {
|
||||
"version": "2.8.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "0BSD",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-win32-arm64-msvc": {
|
||||
"version": "4.1.11",
|
||||
"resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.11.tgz",
|
||||
@@ -4946,6 +5008,13 @@
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/mdast": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",
|
||||
"integrity": "sha512-xySmf8g4fPKMeC07jXGz971EkLbWAJ83s4US2Tj9lEdnZ142UP5grN73H1Xd3HzrdbU5o9GYYP/y8F9ZSwLE9g==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/mdast-util-find-and-replace": {
|
||||
"version": "3.0.2",
|
||||
"resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz",
|
||||
|
||||
@@ -52,6 +52,7 @@
|
||||
"eslint-plugin-svelte": "^3.0.0",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^16.0.0",
|
||||
"mdast": "^3.0.0",
|
||||
"mdsvex": "^0.12.3",
|
||||
"playwright": "^1.53.0",
|
||||
"prettier": "^3.4.2",
|
||||
@@ -68,6 +69,7 @@
|
||||
"tw-animate-css": "^1.3.5",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.20.0",
|
||||
"unified": "^11.0.5",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "^7.0.4",
|
||||
"vite-plugin-devtools-json": "^0.2.0",
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
import { Check, X } from '@lucide/svelte';
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { ChatAttachmentsList } from '$lib/components/app';
|
||||
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
|
||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -55,6 +56,7 @@
|
||||
|
||||
let isMultiline = $state(false);
|
||||
let messageElement: HTMLElement | undefined = $state();
|
||||
const currentConfig = config();
|
||||
|
||||
$effect(() => {
|
||||
if (!messageElement || !message.content.trim()) return;
|
||||
@@ -123,9 +125,18 @@
|
||||
class="max-w-[80%] rounded-[1.125rem] bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
>
|
||||
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
|
||||
{message.content}
|
||||
</span>
|
||||
{#if currentConfig.renderUserContentAsMarkdown}
|
||||
<div bind:this={messageElement} class="text-md">
|
||||
<MarkdownContent
|
||||
class="markdown-user-content text-primary-foreground"
|
||||
content={message.content}
|
||||
/>
|
||||
</div>
|
||||
{:else}
|
||||
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
|
||||
{message.content}
|
||||
</span>
|
||||
{/if}
|
||||
</Card>
|
||||
{/if}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
ChatMessages,
|
||||
ChatProcessingInfo,
|
||||
EmptyFileAlertDialog,
|
||||
ChatErrorDialog,
|
||||
ServerErrorSplash,
|
||||
ServerInfo,
|
||||
ServerLoadingSplash,
|
||||
@@ -22,10 +23,11 @@
|
||||
activeMessages,
|
||||
activeConversation,
|
||||
deleteConversation,
|
||||
dismissErrorDialog,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
sendMessage,
|
||||
stopGeneration,
|
||||
setMaxContextError
|
||||
stopGeneration
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
supportsVision,
|
||||
@@ -34,7 +36,6 @@
|
||||
serverWarning,
|
||||
serverStore
|
||||
} from '$lib/stores/server.svelte';
|
||||
import { contextService } from '$lib/services';
|
||||
import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra';
|
||||
import { isFileTypeSupported } from '$lib/utils/file-type';
|
||||
import { filterFilesByModalities } from '$lib/utils/modality-file-validation';
|
||||
@@ -79,6 +80,7 @@
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
|
||||
async function handleDeleteConfirm() {
|
||||
@@ -105,6 +107,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleErrorDialogOpenChange(open: boolean) {
|
||||
if (!open) {
|
||||
dismissErrorDialog();
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
}
|
||||
@@ -183,21 +191,6 @@
|
||||
|
||||
const extras = result?.extras;
|
||||
|
||||
// Check context limit using real-time slots data
|
||||
const contextCheck = await contextService.checkContextLimit();
|
||||
|
||||
if (contextCheck && contextCheck.wouldExceed) {
|
||||
const errorMessage = contextService.getContextErrorMessage(contextCheck);
|
||||
|
||||
setMaxContextError({
|
||||
message: errorMessage,
|
||||
estimatedTokens: contextCheck.currentUsage,
|
||||
maxContext: contextCheck.maxContext
|
||||
});
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enable autoscroll for user-initiated message sending
|
||||
userScrolledUp = false;
|
||||
autoScrollEnabled = true;
|
||||
@@ -461,6 +454,13 @@
|
||||
}}
|
||||
/>
|
||||
|
||||
<ChatErrorDialog
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? 'server'}
|
||||
/>
|
||||
|
||||
<style>
|
||||
.conversation-chat-form {
|
||||
position: relative;
|
||||
|
||||
@@ -80,6 +80,11 @@
|
||||
key: 'showModelInfo',
|
||||
label: 'Show model information',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
label: 'Render user content as Markdown',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
<script lang="ts">
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { AlertTriangle, TimerOff } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
type: 'timeout' | 'server';
|
||||
message: string;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
let { open = $bindable(), type, message, onOpenChange }: Props = $props();
|
||||
|
||||
const isTimeout = $derived(type === 'timeout');
|
||||
const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
|
||||
const description = $derived(
|
||||
isTimeout
|
||||
? 'The request did not receive a response from the server before timing out.'
|
||||
: 'The server responded with an error message. Review the details below.'
|
||||
);
|
||||
const iconClass = $derived(isTimeout ? 'text-destructive' : 'text-amber-500');
|
||||
const badgeClass = $derived(
|
||||
isTimeout
|
||||
? 'border-destructive/40 bg-destructive/10 text-destructive'
|
||||
: 'border-amber-500/40 bg-amber-500/10 text-amber-600 dark:text-amber-400'
|
||||
);
|
||||
|
||||
function handleOpenChange(newOpen: boolean) {
|
||||
open = newOpen;
|
||||
onOpenChange?.(newOpen);
|
||||
}
|
||||
</script>
|
||||
|
||||
<AlertDialog.Root {open} onOpenChange={handleOpenChange}>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title class="flex items-center gap-2">
|
||||
{#if isTimeout}
|
||||
<TimerOff class={`h-5 w-5 ${iconClass}`} />
|
||||
{:else}
|
||||
<AlertTriangle class={`h-5 w-5 ${iconClass}`} />
|
||||
{/if}
|
||||
|
||||
{title}
|
||||
</AlertDialog.Title>
|
||||
|
||||
<AlertDialog.Description>
|
||||
{description}
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
<div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
|
||||
<p class="font-medium">{message}</p>
|
||||
</div>
|
||||
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Action onclick={() => handleOpenChange(false)}>Close</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
@@ -1,66 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { AlertTriangle } from '@lucide/svelte';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { maxContextError, clearMaxContextError } from '$lib/stores/chat.svelte';
|
||||
</script>
|
||||
|
||||
<AlertDialog.Root
|
||||
open={maxContextError() !== null}
|
||||
onOpenChange={(open) => !open && clearMaxContextError()}
|
||||
>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title class="flex items-center gap-2">
|
||||
<AlertTriangle class="h-5 w-5 text-destructive" />
|
||||
|
||||
Message Too Long
|
||||
</AlertDialog.Title>
|
||||
|
||||
<AlertDialog.Description>
|
||||
Your message exceeds the model's context window and cannot be processed.
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
{#if maxContextError()}
|
||||
<div class="space-y-3 text-sm">
|
||||
<div class="rounded-lg bg-muted p-3">
|
||||
<div class="mb-2 font-medium">Token Usage:</div>
|
||||
|
||||
<div class="space-y-1 text-muted-foreground">
|
||||
<div>
|
||||
Estimated tokens:
|
||||
|
||||
<span class="font-mono">
|
||||
{maxContextError()?.estimatedTokens.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
Context window:
|
||||
|
||||
<span class="font-mono">
|
||||
{maxContextError()?.maxContext.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class="mb-2 font-medium">Suggestions:</div>
|
||||
|
||||
<ul class="list-inside list-disc space-y-1 text-muted-foreground">
|
||||
<li>Shorten your message</li>
|
||||
|
||||
<li>Remove some file attachments</li>
|
||||
|
||||
<li>Start a new conversation</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Action onclick={() => clearMaxContextError()}>Got it</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
@@ -30,12 +30,11 @@ export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
|
||||
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
|
||||
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
|
||||
|
||||
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
|
||||
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
|
||||
|
||||
export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte';
|
||||
|
||||
export { default as MaximumContextAlertDialog } from './dialogs/MaximumContextAlertDialog.svelte';
|
||||
|
||||
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
|
||||
|
||||
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
import { mode } from 'mode-watcher';
|
||||
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
|
||||
|
||||
interface Props {
|
||||
content: string;
|
||||
@@ -50,36 +51,59 @@
|
||||
.use(remarkGfm) // GitHub Flavored Markdown
|
||||
.use(remarkMath) // Parse $inline$ and $$block$$ math
|
||||
.use(remarkBreaks) // Convert line breaks to <br>
|
||||
.use(remarkRehype) // Convert to rehype (HTML AST)
|
||||
.use(remarkLiteralHtml) // Treat raw HTML as literal text with preserved indentation
|
||||
.use(remarkRehype) // Convert Markdown AST to rehype
|
||||
.use(rehypeKatex) // Render math using KaTeX
|
||||
.use(rehypeHighlight) // Add syntax highlighting
|
||||
.use(rehypeStringify); // Convert to HTML string
|
||||
});
|
||||
|
||||
function enhanceLinks(html: string): string {
|
||||
if (!html.includes('<a')) {
|
||||
return html;
|
||||
}
|
||||
|
||||
const tempDiv = document.createElement('div');
|
||||
tempDiv.innerHTML = html;
|
||||
|
||||
// Make all links open in new tabs
|
||||
const linkElements = tempDiv.querySelectorAll('a[href]');
|
||||
let mutated = false;
|
||||
|
||||
for (const link of linkElements) {
|
||||
const target = link.getAttribute('target');
|
||||
const rel = link.getAttribute('rel');
|
||||
|
||||
if (target !== '_blank' || rel !== 'noopener noreferrer') {
|
||||
mutated = true;
|
||||
}
|
||||
|
||||
link.setAttribute('target', '_blank');
|
||||
link.setAttribute('rel', 'noopener noreferrer');
|
||||
}
|
||||
|
||||
return tempDiv.innerHTML;
|
||||
return mutated ? tempDiv.innerHTML : html;
|
||||
}
|
||||
|
||||
function enhanceCodeBlocks(html: string): string {
|
||||
if (!html.includes('<pre')) {
|
||||
return html;
|
||||
}
|
||||
|
||||
const tempDiv = document.createElement('div');
|
||||
tempDiv.innerHTML = html;
|
||||
|
||||
const preElements = tempDiv.querySelectorAll('pre');
|
||||
let mutated = false;
|
||||
|
||||
for (const [index, pre] of Array.from(preElements).entries()) {
|
||||
const codeElement = pre.querySelector('code');
|
||||
|
||||
if (!codeElement) continue;
|
||||
if (!codeElement) {
|
||||
continue;
|
||||
}
|
||||
|
||||
mutated = true;
|
||||
|
||||
let language = 'text';
|
||||
const classList = Array.from(codeElement.classList);
|
||||
@@ -127,7 +151,7 @@
|
||||
pre.parentNode?.replaceChild(wrapper, pre);
|
||||
}
|
||||
|
||||
return tempDiv.innerHTML;
|
||||
return mutated ? tempDiv.innerHTML : html;
|
||||
}
|
||||
|
||||
async function processMarkdown(text: string): Promise<string> {
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
export const LINE_BREAK = /\r?\n/;
|
||||
|
||||
export const PHRASE_PARENTS = new Set([
|
||||
'paragraph',
|
||||
'heading',
|
||||
'emphasis',
|
||||
'strong',
|
||||
'delete',
|
||||
'link',
|
||||
'linkReference',
|
||||
'tableCell'
|
||||
]);
|
||||
|
||||
export const NBSP = '\u00a0';
|
||||
export const TAB_AS_SPACES = NBSP.repeat(4);
|
||||
@@ -12,6 +12,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||
pasteLongTextToFileLen: 2500,
|
||||
pdfAsImage: false,
|
||||
showModelInfo: false,
|
||||
renderUserContentAsMarkdown: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'top_k;typ_p;top_p;min_p;temperature',
|
||||
temperature: 0.8,
|
||||
@@ -84,6 +85,7 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
|
||||
showModelInfo: 'Display the model name used to generate each message below the message content.',
|
||||
renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.',
|
||||
pyInterpreterEnabled:
|
||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
|
||||
};
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
import type { Plugin } from 'unified';
|
||||
import { visit } from 'unist-util-visit';
|
||||
import type { Break, Content, Paragraph, PhrasingContent, Root, Text } from 'mdast';
|
||||
import { LINE_BREAK, NBSP, PHRASE_PARENTS, TAB_AS_SPACES } from '$lib/constants/literal-html';
|
||||
|
||||
/**
|
||||
* remark plugin that rewrites raw HTML nodes into plain-text equivalents.
|
||||
*
|
||||
* remark parses inline HTML into `html` nodes even when we do not want to render
|
||||
* them. We turn each of those nodes into regular text (plus `<br>` break markers)
|
||||
* so the downstream rehype pipeline escapes the characters instead of executing
|
||||
* them. Leading spaces and tab characters are converted to non‑breaking spaces to
|
||||
* keep indentation identical to the original author input.
|
||||
*/
|
||||
|
||||
function preserveIndent(line: string): string {
|
||||
let index = 0;
|
||||
let output = '';
|
||||
|
||||
while (index < line.length) {
|
||||
const char = line[index];
|
||||
|
||||
if (char === ' ') {
|
||||
output += NBSP;
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (char === '\t') {
|
||||
output += TAB_AS_SPACES;
|
||||
index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
return output + line.slice(index);
|
||||
}
|
||||
|
||||
function createLiteralChildren(value: string): PhrasingContent[] {
|
||||
const lines = value.split(LINE_BREAK);
|
||||
const nodes: PhrasingContent[] = [];
|
||||
|
||||
for (const [lineIndex, rawLine] of lines.entries()) {
|
||||
if (lineIndex > 0) {
|
||||
nodes.push({ type: 'break' } as Break as unknown as PhrasingContent);
|
||||
}
|
||||
|
||||
nodes.push({
|
||||
type: 'text',
|
||||
value: preserveIndent(rawLine)
|
||||
} as Text as unknown as PhrasingContent);
|
||||
}
|
||||
|
||||
if (!nodes.length) {
|
||||
nodes.push({ type: 'text', value: '' } as Text as unknown as PhrasingContent);
|
||||
}
|
||||
|
||||
return nodes;
|
||||
}
|
||||
|
||||
export const remarkLiteralHtml: Plugin<[], Root> = () => {
|
||||
return (tree) => {
|
||||
visit(tree, 'html', (node, index, parent) => {
|
||||
if (!parent || typeof index !== 'number') {
|
||||
return;
|
||||
}
|
||||
|
||||
const replacement = createLiteralChildren(node.value);
|
||||
|
||||
if (!PHRASE_PARENTS.has(parent.type as string)) {
|
||||
const paragraph: Paragraph = {
|
||||
type: 'paragraph',
|
||||
children: replacement as Paragraph['children'],
|
||||
data: { literalHtml: true }
|
||||
};
|
||||
|
||||
const siblings = parent.children as unknown as Content[];
|
||||
siblings.splice(index, 1, paragraph as unknown as Content);
|
||||
|
||||
if (index > 0) {
|
||||
const previous = siblings[index - 1] as Paragraph | undefined;
|
||||
|
||||
if (
|
||||
previous?.type === 'paragraph' &&
|
||||
(previous.data as { literalHtml?: boolean } | undefined)?.literalHtml
|
||||
) {
|
||||
const prevChildren = previous.children as unknown as PhrasingContent[];
|
||||
|
||||
if (prevChildren.length) {
|
||||
const lastChild = prevChildren[prevChildren.length - 1];
|
||||
|
||||
if (lastChild.type !== 'break') {
|
||||
prevChildren.push({
|
||||
type: 'break'
|
||||
} as Break as unknown as PhrasingContent);
|
||||
}
|
||||
}
|
||||
|
||||
prevChildren.push(...(paragraph.children as unknown as PhrasingContent[]));
|
||||
|
||||
siblings.splice(index, 1);
|
||||
|
||||
return index;
|
||||
}
|
||||
}
|
||||
|
||||
return index + 1;
|
||||
}
|
||||
|
||||
(parent.children as unknown as PhrasingContent[]).splice(
|
||||
index,
|
||||
1,
|
||||
...(replacement as unknown as PhrasingContent[])
|
||||
);
|
||||
|
||||
return index + replacement.length;
|
||||
});
|
||||
};
|
||||
};
|
||||
@@ -13,7 +13,7 @@ import { slotsService } from './slots';
|
||||
* - Manages streaming and non-streaming response parsing
|
||||
* - Provides request abortion capabilities
|
||||
* - Converts database messages to API format
|
||||
* - Handles error translation and context detection
|
||||
* - Handles error translation for server responses
|
||||
*
|
||||
* - **ChatStore**: Stateful orchestration and UI state management
|
||||
* - Uses ChatService for all AI model communication
|
||||
@@ -26,7 +26,6 @@ import { slotsService } from './slots';
|
||||
* - Streaming response handling with real-time callbacks
|
||||
* - Reasoning content extraction and processing
|
||||
* - File attachment processing (images, PDFs, audio, text)
|
||||
* - Context error detection and reporting
|
||||
* - Request lifecycle management (abort, cleanup)
|
||||
*/
|
||||
export class ChatService {
|
||||
@@ -209,10 +208,13 @@ export class ChatService {
|
||||
userFriendlyError = new Error(
|
||||
'Unable to connect to server - please check if the server is running'
|
||||
);
|
||||
userFriendlyError.name = 'NetworkError';
|
||||
} else if (error.message.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = new Error('Connection refused - server may be offline');
|
||||
userFriendlyError.name = 'NetworkError';
|
||||
} else if (error.message.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = new Error('Request timeout - server may be overloaded');
|
||||
userFriendlyError = new Error('Request timed out - the server took too long to respond');
|
||||
userFriendlyError.name = 'TimeoutError';
|
||||
} else {
|
||||
userFriendlyError = error;
|
||||
}
|
||||
@@ -262,6 +264,7 @@ export class ChatService {
|
||||
let fullReasoningContent = '';
|
||||
let hasReceivedData = false;
|
||||
let lastTimings: ChatMessageTimings | undefined;
|
||||
let streamFinished = false;
|
||||
|
||||
try {
|
||||
let chunk = '';
|
||||
@@ -277,18 +280,8 @@ export class ChatService {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6);
|
||||
if (data === '[DONE]') {
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
return;
|
||||
}
|
||||
|
||||
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
|
||||
|
||||
return;
|
||||
streamFinished = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -326,13 +319,13 @@ export class ChatService {
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
return;
|
||||
if (streamFinished) {
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
|
||||
}
|
||||
} catch (error) {
|
||||
const err = error instanceof Error ? error : new Error('Stream error');
|
||||
@@ -368,12 +361,8 @@ export class ChatService {
|
||||
const responseText = await response.text();
|
||||
|
||||
if (!responseText.trim()) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
throw contextError;
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
const data: ApiChatCompletionResponse = JSON.parse(responseText);
|
||||
@@ -385,22 +374,14 @@ export class ChatService {
|
||||
}
|
||||
|
||||
if (!content.trim()) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
throw contextError;
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(content, reasoningContent);
|
||||
|
||||
return content;
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'ContextError') {
|
||||
throw error;
|
||||
}
|
||||
|
||||
const err = error instanceof Error ? error : new Error('Parse error');
|
||||
|
||||
onError?.(err);
|
||||
@@ -594,37 +575,19 @@ export class ChatService {
|
||||
const errorText = await response.text();
|
||||
const errorData: ApiErrorResponse = JSON.parse(errorText);
|
||||
|
||||
if (errorData.error?.type === 'exceed_context_size_error') {
|
||||
const contextError = errorData.error as ApiContextSizeError;
|
||||
const error = new Error(contextError.message);
|
||||
error.name = 'ContextError';
|
||||
// Attach structured context information
|
||||
(
|
||||
error as Error & {
|
||||
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
|
||||
}
|
||||
).contextInfo = {
|
||||
promptTokens: contextError.n_prompt_tokens,
|
||||
maxContext: contextError.n_ctx,
|
||||
estimatedTokens: contextError.n_prompt_tokens
|
||||
};
|
||||
return error;
|
||||
}
|
||||
|
||||
// Fallback for other error types
|
||||
const message = errorData.error?.message || 'Unknown server error';
|
||||
return new Error(message);
|
||||
const error = new Error(message);
|
||||
error.name = response.status === 400 ? 'ServerError' : 'HttpError';
|
||||
|
||||
return error;
|
||||
} catch {
|
||||
// If we can't parse the error response, return a generic error
|
||||
return new Error(`Server error (${response.status}): ${response.statusText}`);
|
||||
const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
|
||||
fallback.name = 'HttpError';
|
||||
return fallback;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the processing state with timing information from the server response
|
||||
* @param timings - Timing data from the API response
|
||||
* @param promptProgress - Progress data from the API response
|
||||
*/
|
||||
private updateProcessingState(
|
||||
timings?: ChatMessageTimings,
|
||||
promptProgress?: ChatMessagePromptProgress
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
import { slotsService } from './slots';
|
||||
|
||||
export interface ContextCheckResult {
|
||||
wouldExceed: boolean;
|
||||
currentUsage: number;
|
||||
maxContext: number;
|
||||
availableTokens: number;
|
||||
reservedTokens: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* ContextService - Context window management and limit checking
|
||||
*
|
||||
* This service provides context window monitoring and limit checking using real-time
|
||||
* server data from the slots service. It helps prevent context overflow by tracking
|
||||
* current usage and calculating available space for new content.
|
||||
*
|
||||
* **Architecture & Relationships:**
|
||||
* - **ContextService** (this class): Context limit monitoring
|
||||
* - Uses SlotsService for real-time context usage data
|
||||
* - Calculates available tokens with configurable reserves
|
||||
* - Provides context limit checking and error messaging
|
||||
* - Helps prevent context window overflow
|
||||
*
|
||||
* - **SlotsService**: Provides current context usage from server slots
|
||||
* - **ChatStore**: Uses context checking before sending messages
|
||||
* - **UI Components**: Display context usage warnings and limits
|
||||
*
|
||||
* **Key Features:**
|
||||
* - **Real-time Context Checking**: Uses live server data for accuracy
|
||||
* - **Token Reservation**: Reserves tokens for response generation
|
||||
* - **Limit Detection**: Prevents context window overflow
|
||||
* - **Usage Reporting**: Detailed context usage statistics
|
||||
* - **Error Messaging**: User-friendly context limit messages
|
||||
* - **Configurable Reserves**: Adjustable token reservation for responses
|
||||
*
|
||||
* **Context Management:**
|
||||
* - Monitors current context usage from active slots
|
||||
* - Calculates available space considering reserved tokens
|
||||
* - Provides early warning before context limits are reached
|
||||
* - Helps optimize conversation length and content
|
||||
*/
|
||||
export class ContextService {
|
||||
private reserveTokens: number;
|
||||
|
||||
constructor(reserveTokens = 512) {
|
||||
this.reserveTokens = reserveTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the context limit would be exceeded
|
||||
*
|
||||
* @returns {Promise<ContextCheckResult | null>} Promise that resolves to the context check result or null if an error occurs
|
||||
*/
|
||||
async checkContextLimit(): Promise<ContextCheckResult | null> {
|
||||
try {
|
||||
const currentState = await slotsService.getCurrentState();
|
||||
|
||||
if (!currentState) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const maxContext = currentState.contextTotal;
|
||||
const currentUsage = currentState.contextUsed;
|
||||
const availableTokens = maxContext - currentUsage - this.reserveTokens;
|
||||
const wouldExceed = availableTokens <= 0;
|
||||
|
||||
return {
|
||||
wouldExceed,
|
||||
currentUsage,
|
||||
maxContext,
|
||||
availableTokens: Math.max(0, availableTokens),
|
||||
reservedTokens: this.reserveTokens
|
||||
};
|
||||
} catch (error) {
|
||||
console.warn('Error checking context limit:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a formatted error message for context limit exceeded
|
||||
*
|
||||
* @param {ContextCheckResult} result - Context check result
|
||||
* @returns {string} Formatted error message
|
||||
*/
|
||||
getContextErrorMessage(result: ContextCheckResult): string {
|
||||
const usagePercent = Math.round((result.currentUsage / result.maxContext) * 100);
|
||||
return `Context window is nearly full. Current usage: ${result.currentUsage.toLocaleString()}/${result.maxContext.toLocaleString()} tokens (${usagePercent}%). Available space: ${result.availableTokens.toLocaleString()} tokens (${result.reservedTokens} reserved for response).`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of tokens to reserve for response generation
|
||||
*
|
||||
* @param {number} tokens - Number of tokens to reserve
|
||||
*/
|
||||
setReserveTokens(tokens: number): void {
|
||||
this.reserveTokens = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
export const contextService = new ContextService();
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user