mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-30 17:47:40 +02:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3f750f8d76 | |||
| c515fc5771 | |||
| f9bc66c3eb | |||
| a31cf36ad9 | |||
| 81d54bbfd5 | |||
| c7be9febcb | |||
| 8415f61e23 | |||
| 2c301e91ab | |||
| 4b2dae383d | |||
| 41aac5c69b | |||
| a2fba89a42 | |||
| 20cc625edc | |||
| 11f0af5504 | |||
| a3cb04744f | |||
| 4a8fbe0a5e | |||
| 31d0ff1869 | |||
| 97870e6497 | |||
| 477a66b035 | |||
| e60f01d941 |
@@ -387,6 +387,39 @@ jobs:
|
||||
cd build
|
||||
ctest -L main --verbose
|
||||
|
||||
ubuntu-24-cmake-vulkan-deb:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-24-cmake-vulkan-deb
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Configure
|
||||
id: cmake_configure
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGGML_VULKAN=ON
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
ubuntu-24-cmake-vulkan:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
|
||||
+161
-133
@@ -3358,7 +3358,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
add_opt(common_arg(
|
||||
{"--chat-template-kwargs"}, "STRING",
|
||||
string_format("sets additional params for the json template parser"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
[](common_params & params, const std::string & value) {
|
||||
auto parsed = json::parse(value);
|
||||
for (const auto & item : parsed.items()) {
|
||||
params.default_template_kwargs[item.key()] = item.value().dump();
|
||||
@@ -3570,21 +3570,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
common_log_set_file(common_log_main(), value.c_str());
|
||||
}
|
||||
));
|
||||
add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"--log-colors"}, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"-v", "--verbose", "--log-verbose"},
|
||||
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
|
||||
@@ -3850,7 +3852,87 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
// model-specific
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-steps"}, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-visual"},
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)", params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-eps"}, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-algorithm"}, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-alg-temp"}, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-block-length"}, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-cfg-scale"}, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-add-gumbel-noise"}, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format("adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)", (double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format("(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-decay-epochs", "--learning-rate-decay-epochs"}, "ALPHA",
|
||||
string_format("(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)", (double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-wd", "--weight-decay"}, "WD",
|
||||
string_format("adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).", (double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-val-split", "--val-split"}, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).", (double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-epochs", "--epochs"}, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-opt", "--optimizer"}, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
}
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
|
||||
// presets
|
||||
add_opt(common_arg(
|
||||
{"--tts-oute-default"},
|
||||
string_format("use default OuteTTS models (note: can download weights from the internet)"),
|
||||
@@ -3863,39 +3945,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-bge-small-en-default"},
|
||||
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
|
||||
{"--embd-gemma-default"},
|
||||
string_format("use default EmbeddingGemma model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
|
||||
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-e5-small-en-default"},
|
||||
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
|
||||
params.model.hf_file = "e5-small-v2-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-gte-small-default"},
|
||||
string_format("use default gte-small model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
|
||||
params.model.hf_file = "gte-small-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.model.hf_repo = "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF";
|
||||
params.model.hf_file = "embeddinggemma-300M-qat-Q4_0.gguf";
|
||||
params.port = 8011;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 2048;
|
||||
params.n_parallel = 32;
|
||||
params.n_ctx = 2048*params.n_parallel;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
@@ -3990,96 +4049,65 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-steps" }, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-visual" },
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
|
||||
params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-20b-default"},
|
||||
string_format("use gpt-oss-20b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-20b-GGUF";
|
||||
params.model.hf_file = "gpt-oss-20b-mxfp4.gguf";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-eps" }, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-algorithm" }, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
|
||||
params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-alg-temp" }, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-120b-default"},
|
||||
string_format("use gpt-oss-120b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-120b-GGUF";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-block-length" }, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-cfg-scale" }, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-add-gumbel-noise" }, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--vision-gemma-4b-default"},
|
||||
string_format("use Gemma 3 4B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-4b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
|
||||
add_opt(
|
||||
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format(
|
||||
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
|
||||
(double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
|
||||
(double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{ "-wd", "--weight-decay" }, "WD",
|
||||
string_format(
|
||||
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
|
||||
(double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).",
|
||||
(double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
})
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
{"--vision-gemma-12b-default"},
|
||||
string_format("use Gemma 3 12B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-12b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
if (is_arguments_path({})) {
|
||||
// Entire JSON is the arguments and was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json.dump(),
|
||||
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
@@ -444,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
|
||||
|
||||
+1
-1
@@ -426,7 +426,7 @@ struct common_params {
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||
int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc.
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -168,6 +169,47 @@ bool common_json_parse(
|
||||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
@@ -186,6 +228,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
@@ -205,6 +250,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
@@ -230,6 +278,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
|
||||
+2
-10
@@ -5966,20 +5966,12 @@ class Mamba2Model(TextModel):
|
||||
class JambaModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.JAMBA
|
||||
|
||||
def get_vocab_base_pre(self, tokenizer) -> str:
|
||||
del tokenizer # unused
|
||||
|
||||
return "gpt-2"
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.model").is_file():
|
||||
# Using Jamba's tokenizer.json causes errors on model load
|
||||
# (something about "byte not found in vocab"),
|
||||
# but there's a working tokenizer.model
|
||||
self._set_vocab_sentencepiece()
|
||||
else:
|
||||
# Some Jamba models only have a tokenizer.json, which works.
|
||||
self._set_vocab_gpt2()
|
||||
self._set_vocab_llama_hf()
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
|
||||
|
||||
+10
-8
@@ -31,7 +31,7 @@ Legend:
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -51,7 +51,7 @@ Legend:
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
@@ -65,11 +65,11 @@ Legend:
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
@@ -92,9 +92,9 @@ Legend:
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -102,9 +102,11 @@ Legend:
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
+12095
-4249
File diff suppressed because it is too large
Load Diff
@@ -894,14 +894,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get or expand a cached float32 tensor filled with a scalar value.
|
||||
* @brief Get or expand a cached tensor filled with a scalar value.
|
||||
*
|
||||
* This function manages cached device memory for float32 tensors. If the current
|
||||
* This function manages cached device memory for tensors. If the current
|
||||
* cache size is insufficient for the requested tensor shape, the old memory will
|
||||
* be released and new memory will be allocated. The allocated buffer is then
|
||||
* initialized either with zeros (when @p value == 0.0f) or with the given scalar
|
||||
* value using CANN operations. Finally, an aclTensor object is created from the
|
||||
* cached memory and returned.
|
||||
* be released and new memory will be allocated. The allocated buffer is
|
||||
* initialized with the given scalar value using CANN operations.
|
||||
* Finally, an aclTensor object is created from the cached memory and returned.
|
||||
*
|
||||
* @param ctx The CANN backend context that manages device memory.
|
||||
* @param buffer A pointer to the cached device buffer (will be allocated
|
||||
@@ -910,17 +909,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
* updated when the cache is expanded.
|
||||
* @param ne The tensor shape array (number of elements in each dimension).
|
||||
* @param nb The stride size for each dimension.
|
||||
* @param dtype Data type of cached tensor.
|
||||
* @param dims The number of tensor dimensions.
|
||||
* @param value The scalar value used to fill the tensor (supports zero
|
||||
* initialization via memset or arbitrary values via fill_scalar).
|
||||
* @return An aclTensor pointer created from the cached buffer.
|
||||
*/
|
||||
static aclTensor* get_f32_cache_acl_tensor(
|
||||
static aclTensor* get_cache_acl_tensor(
|
||||
ggml_backend_cann_context& ctx,
|
||||
void** buffer,
|
||||
int64_t &cache_element,
|
||||
int64_t* ne,
|
||||
size_t* nb,
|
||||
ggml_type dtype,
|
||||
int64_t dims,
|
||||
float value) {
|
||||
// Calculate total number of elements
|
||||
@@ -928,7 +929,7 @@ static aclTensor* get_f32_cache_acl_tensor(
|
||||
for (int i = 0; i < dims; i++) {
|
||||
n_element *= ne[i];
|
||||
}
|
||||
size_t size = n_element * sizeof(float);
|
||||
size_t size = n_element * ggml_type_size(dtype);
|
||||
|
||||
// Allocate or expand cache if needed
|
||||
if (cache_element < n_element) {
|
||||
@@ -941,19 +942,17 @@ static aclTensor* get_f32_cache_acl_tensor(
|
||||
cache_element = n_element;
|
||||
|
||||
// Initialize cache
|
||||
if (value == 0.0f) {
|
||||
ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
|
||||
} else {
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { sizeof(float) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, 1, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { ggml_type_size(dtype) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
|
||||
pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, value, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
|
||||
return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
|
||||
return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
|
||||
ggml_type_size(dtype), ne, nb, dims);
|
||||
}
|
||||
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
@@ -965,35 +964,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
// build gamma, one...
|
||||
// build gamma.
|
||||
size_t acl_gamma_nb[GGML_MAX_DIMS];
|
||||
acl_gamma_nb[0] = sizeof(float);
|
||||
// gamma's type is the same with dst.
|
||||
acl_gamma_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_gamma = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_one_tensor_cache.cache,
|
||||
ctx.rms_norm_one_tensor_cache.size,
|
||||
src->ne,
|
||||
acl_gamma_nb,
|
||||
dst->type,
|
||||
1, // dims
|
||||
1.0f // value
|
||||
);
|
||||
|
||||
// build rstd, zero...
|
||||
// build rstd.
|
||||
int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
|
||||
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
|
||||
// rstd will always be F32.
|
||||
acl_rstd_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
|
||||
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_rstd = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_zero_tensor_cache.cache,
|
||||
ctx.rms_norm_zero_tensor_cache.size,
|
||||
acl_rstd_ne,
|
||||
acl_rstd_nb,
|
||||
GGML_TYPE_F32,
|
||||
GGML_MAX_DIMS - 1,
|
||||
0.0f // value
|
||||
);
|
||||
@@ -1765,33 +1768,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src0 = dst->src[0]; // src
|
||||
ggml_tensor* src1 = dst->src[1]; // index
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_F16: {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
if(src0->type == dst->type) {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
} else {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = dst->nb[0];
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
// add 1 dim for bcast mul.
|
||||
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
|
||||
@@ -1799,7 +1804,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
|
||||
*dequant_ne;
|
||||
int64_t scale_offset = 0;
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
weight_ne[0] = QK8_0;
|
||||
weight_ne[1] = src0->ne[0] / QK8_0;
|
||||
@@ -1809,7 +1813,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
weight_ne[i] = src0->ne[i - 1];
|
||||
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,1]
|
||||
scale_ne[0] = 1;
|
||||
scale_ne[1] = src0->ne[0] / QK8_0;
|
||||
@@ -1819,18 +1822,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
scale_ne[i] = src0->ne[i - 1];
|
||||
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
dequant_ne = weight_ne;
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
|
||||
}
|
||||
|
||||
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
|
||||
ggml_cann_pool_alloc dequant_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
|
||||
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
|
||||
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
|
||||
GGML_MAX_DIMS + 1);
|
||||
@@ -1838,16 +1838,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
|
||||
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
|
||||
aclTensor* dequant_tensor = ggml_cann_create_tensor(
|
||||
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
|
||||
dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
|
||||
|
||||
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
dequant_ne = src0->ne;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
|
||||
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
|
||||
dequant_ne, dequant_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
@@ -1965,16 +1963,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
|
||||
// Only check env once.
|
||||
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (weight_to_nz && is_matmul_weight(weight)) {
|
||||
int64_t acl_stride[2] = {1, transpose_ne[1]};
|
||||
|
||||
// Reverse ne.
|
||||
std::reverse(transpose_ne, transpose_ne + n_dims);
|
||||
|
||||
std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
|
||||
|
||||
acl_weight_tensor = aclCreateTensor(
|
||||
transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
|
||||
0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
|
||||
} else {
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
|
||||
@@ -3178,7 +3168,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
aclTensor* acl_src0_f16_tensor = nullptr;
|
||||
aclTensor* acl_src1_f16_tensor = nullptr;
|
||||
aclTensor* acl_src2_f16_tensor = nullptr;
|
||||
aclTensor* acl_dst_f16_tensor = nullptr;
|
||||
|
||||
// Step 1: cast the src0 (Query) to fp16 if needed
|
||||
ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
|
||||
@@ -3216,22 +3205,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
|
||||
src2_bsnd_nb, GGML_MAX_DIMS);
|
||||
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
acl_dst_f16_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
|
||||
// Step 3: create the PSEShift tensor if needed
|
||||
// this tensor is considered as mask (f16) in the llama.cpp
|
||||
aclTensor* bcast_pse_tensor = nullptr;
|
||||
@@ -3334,8 +3307,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
int64_t keyAntiquantMode = 0;
|
||||
int64_t valueAntiquantMode = 0;
|
||||
|
||||
// Step 5: launch the FusedInferAttentionScoreV2 kernel.
|
||||
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
aclTensor * fa_dst_tensor = nullptr;
|
||||
aclTensor * acl_dst_tensor = nullptr;
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
fa_dst_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
}
|
||||
else {
|
||||
fa_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
}
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
|
||||
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
|
||||
@@ -3357,23 +3351,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
blockSize, antiquantMode, // blockSize, antiquantMode
|
||||
softmaxLseFlag, // softmaxLseFlag
|
||||
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
|
||||
acl_dst_f16_tensor, // attentionOut
|
||||
fa_dst_tensor, // attentionOut
|
||||
nullptr // softmaxLse
|
||||
);
|
||||
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
// TODO: when dst is fp16, don't need cast
|
||||
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_src1_f16_tensor,
|
||||
acl_src2_f16_tensor,
|
||||
acl_dst_f16_tensor,
|
||||
acl_dst_tensor);
|
||||
if(src3 != nullptr){
|
||||
ggml_cann_release_resources(ctx, bcast_pse_tensor);
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
}
|
||||
}else{
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_src1_f16_tensor,
|
||||
acl_src2_f16_tensor,
|
||||
fa_dst_tensor,
|
||||
acl_dst_tensor,
|
||||
bcast_pse_tensor);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("Function is not implemented.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - mean;
|
||||
y[i] = val;
|
||||
val *= val;
|
||||
sum += (ggml_float)val;
|
||||
y[i] = val;
|
||||
}
|
||||
return sum/n;
|
||||
}
|
||||
|
||||
@@ -144,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
|
||||
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
|
||||
@@ -160,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
|
||||
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
@@ -820,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
|
||||
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
|
||||
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x[i]);
|
||||
y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
|
||||
}
|
||||
}
|
||||
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
|
||||
|
||||
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||
|
||||
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
|
||||
return GGML_CUDA_CC_IS_AMD(cc) ||
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
|
||||
}
|
||||
|
||||
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
|
||||
// Important: do not use this function if dst and src both point at registers.
|
||||
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
|
||||
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
|
||||
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
||||
if constexpr (alignment != 0) {
|
||||
|
||||
@@ -793,8 +793,6 @@ void launch_fattn(
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
@@ -878,7 +876,7 @@ void launch_fattn(
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
// multiple sequences of possibly different lengths.
|
||||
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
|
||||
@@ -916,8 +914,7 @@ void launch_fattn(
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
@@ -946,7 +943,7 @@ void launch_fattn(
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
|
||||
@@ -1,756 +1,45 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
|
||||
// kq_stride == number of KQ rows to process per iteration
|
||||
// kq_nbatch == number of K columns to load in parallel for KQ calculation
|
||||
|
||||
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
if (GGML_CUDA_CC_IS_RDNA(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
if (fast_fp16_available(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
GGML_UNUSED(warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
#ifdef RDNA
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // RDNA
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
return 128;
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED_VARS(cc, ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
|
||||
#ifdef RDNA
|
||||
return 3;
|
||||
#else
|
||||
return ncols <= 16 ? 3 : 2;
|
||||
#endif // RDNA
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
|
||||
static __global__ void flash_attn_tile(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef GGML_USE_WMMA_FATTN
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // GGML_USE_WMMA_FATTN
|
||||
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = 32;
|
||||
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
|
||||
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
|
||||
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
|
||||
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
|
||||
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
|
||||
|
||||
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
constexpr int cpw = ncols/nwarps; // cols per warp
|
||||
|
||||
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
||||
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
||||
|
||||
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ half2 Q_tmp[ncols][D/2];
|
||||
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#else
|
||||
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
||||
|
||||
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ float Q_tmp[ncols][D];
|
||||
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
||||
|
||||
float KQ_max[cpw];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
// Load Q data, convert to FP16 if fast.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
const int j = j0 + threadIdx.y*cpw;
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp_f[i1] *= scale;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float KQ_max_new[cpw];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < cpw; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
||||
|
||||
// KQ = K @ Q matrix multiplication:
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
||||
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
half2 tmp_h2[cpy_ne_kqnb/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_kqnb/2];
|
||||
#pragma unroll
|
||||
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
||||
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
|
||||
half2 K_k[kq_stride/warp_size][cpy_ne];
|
||||
half2 Q_k[cpw][cpy_ne];
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
|
||||
float K_k[kq_stride/warp_size][cpy_ne];
|
||||
float Q_k[cpw][cpy_ne];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < cpy_ne; ++k) {
|
||||
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (k_KQ_0 + kq_nbatch < D) {
|
||||
__syncthreads(); // Sync not needed on last iteration.
|
||||
}
|
||||
}
|
||||
|
||||
// Apply logit softcap, mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#else
|
||||
float tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
|
||||
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
|
||||
KQ_max[j0+j1] = KQ_max_new[j0+j1];
|
||||
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/warp_size][j1] = val;
|
||||
}
|
||||
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
|
||||
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
|
||||
}
|
||||
}
|
||||
|
||||
// VKQ = V @ KQ matrix multiplication:
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
|
||||
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
|
||||
const int k_tile = k1 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
|
||||
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
half2 V_k[(D/2)/warp_size];
|
||||
half2 KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
half tmp[softmax_iter_j];
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
|
||||
&tmp, KQ[j][k0 + k1]);
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_k[j0+j1] = __half2half2(tmp[j1]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
|
||||
&KQ_k[j0], KQ[j][k0 + k1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
|
||||
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Attention sink: adjust running max and sum once per head
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
|
||||
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
|
||||
KQ_max[j0] = KQ_max_new_j;
|
||||
|
||||
const float val = expf(sink - KQ_max[j0]);
|
||||
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_sum[j0] += val;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
|
||||
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
float2 tmp[cpy_ne_D];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (D <= 128) {
|
||||
if (Q->ne[1] > 32) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
if (Q->ne[1] > 16) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
}
|
||||
|
||||
template <bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
switch (K->ne[0]) {
|
||||
case 40: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
|
||||
} break;
|
||||
case 64: {
|
||||
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
|
||||
} break;
|
||||
case 80: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
|
||||
} break;
|
||||
case 96: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
|
||||
} break;
|
||||
case 112: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
|
||||
} break;
|
||||
case 128: {
|
||||
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
|
||||
} break;
|
||||
case 256: {
|
||||
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("Unsupported head size");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
|
||||
+41
-35
@@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
#endif// FLASH_ATTN_AVAILABLE
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
@@ -206,37 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
// The effective batch size for the kernel can be increased by gqa_ratio.
|
||||
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
||||
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
// TODO: temporary until support is extended
|
||||
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
|
||||
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
switch (K->ne[0]) {
|
||||
case 40:
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 80:
|
||||
case 96:
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 576:
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
@@ -270,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
|
||||
|
||||
// If Turing tensor cores available, use them except for some cases with batch size 1:
|
||||
if (turing_mma_available(cc)) {
|
||||
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
if (Q->ne[1] <= 2) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
|
||||
if (!gqa_opt_applies && Q->ne[1] == 1) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
|
||||
return best;
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use kernels specialized for small batch sizes if possible:
|
||||
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
||||
// For large batch sizes, use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc)) {
|
||||
// Use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
|
||||
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
return BEST_FATTN_KERNEL_TILE;
|
||||
}
|
||||
|
||||
|
||||
@@ -3867,7 +3867,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(112, 112);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(128, 128);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(256, 256);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(40, 40);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(576, 512);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(64, 64);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(80, 80);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(96, 96);
|
||||
@@ -3,8 +3,17 @@
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
|
||||
|
||||
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
@@ -51,6 +60,11 @@ def get_short_name(long_quant_name):
|
||||
for filename in glob("*.cu"):
|
||||
os.remove(filename)
|
||||
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type_k in TYPES_KV:
|
||||
for type_v in TYPES_KV:
|
||||
with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
@@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]:
|
||||
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
|
||||
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
if head_size_kq == 40:
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
|
||||
@@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
|
||||
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
|
||||
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
|
||||
@@ -1482,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_SGD);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
@@ -134,6 +135,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
|
||||
@@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
@@ -798,6 +799,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return false;
|
||||
};
|
||||
}
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return has_simdgroup_reduction;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -544,6 +544,10 @@ typedef struct{
|
||||
float limit;
|
||||
} ggml_metal_kargs_glu;
|
||||
|
||||
typedef struct {
|
||||
uint64_t np;
|
||||
} ggml_metal_kargs_sum;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -773,4 +777,12 @@ typedef struct {
|
||||
uint64_t nb01;
|
||||
} ggml_metal_kargs_argmax;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_adamw;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_sgd;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
||||
@@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_glu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_sum(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
@@ -410,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
||||
@@ -840,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
|
||||
|
||||
ggml_metal_kargs_sum args = {
|
||||
/*.np =*/ n,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -1546,9 +1582,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
!ggml_is_transposed(op->src[1]) &&
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
||||
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
|
||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
|
||||
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
@@ -3402,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_adamw args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_sgd args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
@@ -78,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_op_sum_f32(
|
||||
constant ggml_metal_kargs_sum & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||
|
||||
if (tiitg != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
float acc = 0.0f;
|
||||
for (ulong i = 0; i < args.np; ++i) {
|
||||
acc += src0[i];
|
||||
}
|
||||
|
||||
dst[0] = acc;
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
@@ -7487,7 +7505,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
||||
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
template<int NR0, typename args_t>
|
||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7500,13 +7518,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK4_NL;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
const int first_row = (r0 * NSG + sgitg) * NR0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
@@ -7517,6 +7534,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int nb = args.ne00/QK4_NL;
|
||||
const int ns01 = args.nb01/args.nb00;
|
||||
|
||||
const short ix = tiisg/2; // 0...15
|
||||
const short it = tiisg%2; // 0 or 1
|
||||
|
||||
@@ -7524,24 +7544,25 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[nr0]={0.f};
|
||||
float sumf[NR0]={0.f};
|
||||
|
||||
device const float * yb = y + ix * QK4_NL + it * 8;
|
||||
device const float * yb = y + ix*QK4_NL + it*8;
|
||||
|
||||
uint32_t aux32[2];
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 16) {
|
||||
// [TAG_MUL_MV_WEIRD]
|
||||
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0];
|
||||
yl[1] = y4[4];
|
||||
yl[2] = y4[1];
|
||||
yl[3] = y4[5];
|
||||
|
||||
for (short row = 0; row < nr0; row++) {
|
||||
device const block_iq4_nl & xb = x[row*nb + ib];
|
||||
for (short row = 0; row < NR0; row++) {
|
||||
device const block_iq4_nl & xb = x[row*ns01 + ib];
|
||||
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
@@ -7572,7 +7593,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
||||
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
@@ -7594,7 +7615,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
template<int NR0, typename args_t>
|
||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7607,12 +7628,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK_K;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
const int first_row = (r0 * NSG + sgitg) * NR0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
@@ -7623,6 +7643,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int nb = args.ne00/QK_K;
|
||||
const int ns01 = args.nb01/args.nb00;
|
||||
|
||||
const short ix = tiisg/16; // 0 or 1
|
||||
const short it = tiisg%16; // 0...15
|
||||
const short ib = it/2;
|
||||
@@ -7632,7 +7655,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[nr0]={0.f};
|
||||
float sumf[NR0]={0.f};
|
||||
|
||||
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
||||
|
||||
@@ -7641,15 +7664,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
|
||||
float4 qf1, qf2;
|
||||
|
||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||
// [TAG_MUL_MV_WEIRD]
|
||||
for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0];
|
||||
yl[1] = y4[4];
|
||||
yl[2] = y4[1];
|
||||
yl[3] = y4[5];
|
||||
|
||||
for (short row = 0; row < nr0; ++row) {
|
||||
device const block_iq4_xs & xb = x[row*nb + ibl];
|
||||
for (short row = 0; row < NR0; ++row) {
|
||||
device const block_iq4_xs & xb = x[row*ns01 + ibl];
|
||||
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
||||
|
||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||
@@ -7679,7 +7703,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
||||
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
@@ -7701,7 +7725,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
template<int NR0, typename args_t>
|
||||
void kernel_mul_mv_mxfp4_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
@@ -7714,13 +7738,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
const int nb = args.ne00/QK_MXFP4;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
const int first_row = (r0 * NSG + sgitg) * NR0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
@@ -7731,6 +7754,9 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int nb = args.ne00/QK_MXFP4;
|
||||
const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
|
||||
|
||||
const short ix = tiisg/2; // 0...15
|
||||
const short it = tiisg%2; // 0 or 1
|
||||
|
||||
@@ -7738,20 +7764,22 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float4 yl[4];
|
||||
float sumf[nr0]={0.f};
|
||||
float sumf[NR0]={0.f};
|
||||
|
||||
device const float * yb = y + ix * QK_MXFP4 + it * 8;
|
||||
device const float * yb = y + ix*QK_MXFP4 + it*8;
|
||||
|
||||
// note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
|
||||
// no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
|
||||
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
|
||||
device const float4 * y4 = (device const float4 *) yb;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += 16) {
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0];
|
||||
yl[1] = y4[4];
|
||||
yl[2] = y4[1];
|
||||
yl[3] = y4[5];
|
||||
|
||||
#pragma unroll(nr0)
|
||||
for (short row = 0; row < nr0; row++) {
|
||||
device const block_mxfp4 & xb = x[row*nb + ib];
|
||||
FOR_UNROLL (short row = 0; row < NR0; row++) {
|
||||
device const block_mxfp4 & xb = x[row*ns01 + ib];
|
||||
device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
|
||||
|
||||
float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
|
||||
@@ -7769,7 +7797,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
||||
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
@@ -8744,3 +8772,51 @@ kernel void kernel_pool_2d_avg_f32(
|
||||
|
||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_adamw_f32(
|
||||
constant ggml_metal_kargs_opt_step_adamw & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device float * g_m,
|
||||
device float * g_v,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float alpha = pars[0];
|
||||
const float beta1 = pars[1];
|
||||
const float beta2 = pars[2];
|
||||
const float eps = pars[3];
|
||||
const float wd = pars[4];
|
||||
const float beta1h = pars[5];
|
||||
const float beta2h = pars[6];
|
||||
|
||||
const float gi = g[gid];
|
||||
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
|
||||
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
|
||||
|
||||
g_m[gid] = gmi;
|
||||
g_v[gid] = gvi;
|
||||
|
||||
const float mh = gmi * beta1h;
|
||||
const float vh = sqrt(gvi * beta2h) + eps;
|
||||
|
||||
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_sgd_f32(
|
||||
constant ggml_metal_kargs_opt_step_sgd & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
|
||||
|
||||
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "concat.hpp"
|
||||
#include "conv.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "count-equal.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "dequantize.hpp"
|
||||
#include "dmmv.hpp"
|
||||
@@ -28,6 +29,7 @@
|
||||
#include "mmvq.hpp"
|
||||
#include "norm.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "pad.hpp"
|
||||
#include "quantize.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "rope.hpp"
|
||||
|
||||
@@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
@@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_sub(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_count_equal(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_mul(ctx, dst);
|
||||
|
||||
@@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_count_equal(const float a, const float b) {
|
||||
return (a == b) ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
static __dpct_inline__ float op_mul(const float a, const float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
@@ -195,7 +195,8 @@ struct optimize_feature {
|
||||
|
||||
struct sycl_device_info {
|
||||
int cc; // compute capability
|
||||
// int nsm; // number of streaming multiprocessors
|
||||
int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
|
||||
// number of compute units on a SYCL device.
|
||||
// size_t smpb; // max. shared memory per block
|
||||
size_t smpbo; // max. shared memory per block (with opt-in)
|
||||
bool vmm; // virtual memory support
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
#include "count-equal.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
template <typename T>
|
||||
static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
|
||||
int64_t *__restrict__ dst, const int64_t dk,
|
||||
const int64_t k) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
|
||||
const int64_t i1 = sycl::min(i0 + dk, k);
|
||||
|
||||
int nequal = 0;
|
||||
|
||||
for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
|
||||
const T xi = x[i];
|
||||
const T yi = y[i];
|
||||
nequal += xi == yi;
|
||||
}
|
||||
|
||||
nequal = warp_reduce_sum(nequal);
|
||||
|
||||
if (item_ct1.get_local_id(2) != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
|
||||
(int *)dst, nequal);
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_I64);
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
int64_t * dst_d = (int64_t *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
const int id = get_current_device_id();
|
||||
const int nsm = ggml_sycl_info().devices[id].nsm;
|
||||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
|
||||
const int64_t dne =
|
||||
GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
|
||||
|
||||
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dpct::dim3 block_nums(
|
||||
std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
|
||||
SYCL_COUNT_EQUAL_CHUNK_SIZE),
|
||||
1, 1);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_I32: {
|
||||
const int *src0_d = (const int *)src0->data;
|
||||
const int *src1_d = (const int *)src1->data;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
count_equal(src0_d, src1_d, dst_d, dne, ne);
|
||||
GGML_UNUSED(item_ct1);
|
||||
});
|
||||
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
#ifndef GGML_SYCL_COUNT_EQUAL_HPP
|
||||
#define GGML_SYCL_COUNT_EQUAL_HPP
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif //GGML_SYCL_COUNT_EQUAL_HPP
|
||||
@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
||||
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
|
||||
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
||||
item_ct1.get_group(0) * ne00 * ne01;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
dst[offset_dst] = static_cast<T>(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
||||
const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void pad_sycl(const T *x, T *dst, const int ne00,
|
||||
const int ne01, const int ne02, const int ne0,
|
||||
const int ne1, const int ne2, queue_ptr stream) {
|
||||
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
}
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ggml_sycl_detail
|
||||
|
||||
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
|
||||
queue_ptr stream) {
|
||||
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
float min_val;
|
||||
float max_val;
|
||||
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_upscale(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
|
||||
@@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -85,9 +85,11 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
||||
|
||||
info.devices[i].cc =
|
||||
100 * prop.get_major_version() + 10 * prop.get_minor_version();
|
||||
info.devices[i].nsm = prop.get_max_compute_units();
|
||||
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
|
||||
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
||||
info.devices[i].smpbo = prop.get_local_mem_size();
|
||||
|
||||
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
||||
}
|
||||
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
@@ -1512,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
|
||||
template <ggml_sort_order order>
|
||||
__dpct_inline__ static void
|
||||
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
|
||||
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
|
||||
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
|
||||
uint8_t *dpct_local) {
|
||||
// bitonic sort
|
||||
int col = item_ct1.get_local_id(2);
|
||||
int col_index = item_ct1.get_local_id(2);
|
||||
int row = item_ct1.get_group(1);
|
||||
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const float * x_row = x + row * ncols;
|
||||
auto dst_row = (int *)dpct_local;
|
||||
|
||||
// initialize indices
|
||||
dst_row[col] = col;
|
||||
for (int i=0;i<tasks_per_thread;i++){
|
||||
int col = col_index*tasks_per_thread+i;
|
||||
dst_row[col] = col;
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols &&
|
||||
(order == GGML_SORT_ORDER_ASC
|
||||
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
|
||||
: x_row[dst_row[col]] <
|
||||
x_row[dst_row[ixj]]))) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols &&
|
||||
(order == GGML_SORT_ORDER_ASC
|
||||
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
|
||||
: x_row[dst_row[col]] >
|
||||
x_row[dst_row[ixj]]))) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
/*
|
||||
DPCT1118:1: SYCL group functions and algorithms must be encountered
|
||||
in converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
@@ -1738,11 +1750,20 @@ static int next_power_of_2(int x) {
|
||||
|
||||
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
const int nrows, ggml_sort_order order,
|
||||
queue_ptr stream) {
|
||||
queue_ptr stream, int device) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, ncols_pad);
|
||||
int nth = 1;
|
||||
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
while (nth < ncols_pad && nth < max_block_size)
|
||||
nth *= 2;
|
||||
if (nth > max_block_size)
|
||||
nth = max_block_size;
|
||||
|
||||
const int tasks_per_thread = ncols_pad / nth;
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, nth);
|
||||
const sycl::range<3> block_nums(1, nrows, 1);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
|
||||
@@ -1755,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
});
|
||||
});
|
||||
@@ -1769,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
});
|
||||
});
|
||||
@@ -2142,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
|
||||
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
|
||||
main_stream, ctx.device);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
@@ -4413,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_PAD:
|
||||
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
|
||||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
//#include "common.hpp"
|
||||
#include "pad.hpp"
|
||||
|
||||
static void pad_f32(const float * src, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
int i0 = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
int i1 = item_ct1.get_group(1);
|
||||
int i2 = item_ct1.get_group(0) % ne2;
|
||||
int i3 = item_ct1.get_group(0) / ne2;
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
|
||||
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
||||
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t i00 = i0 - lp0;
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
|
||||
i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
|
||||
const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3,
|
||||
const int rp3, const int ne0, const int ne1,
|
||||
const int ne2, const int ne3,
|
||||
dpct::queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
||||
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
|
||||
ne2, ne3);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
|
||||
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
|
||||
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
|
||||
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
|
||||
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
|
||||
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
|
||||
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
|
||||
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
|
||||
|
||||
pad_f32_sycl(src0_d, dst_d,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_PAD_HPP
|
||||
#define GGML_SYCL_PAD_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_PAD_BLOCK_SIZE 256
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_PAD_HPP
|
||||
@@ -140,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const {
|
||||
}
|
||||
|
||||
bool llama_hparams::is_recurrent(uint32_t il) const {
|
||||
return recurrent_layer_arr[il];
|
||||
if (il < n_layer) {
|
||||
return recurrent_layer_arr[il];
|
||||
}
|
||||
|
||||
GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer);
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_pos_per_embd() const {
|
||||
|
||||
+4
-4
@@ -16313,10 +16313,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
||||
}
|
||||
|
||||
ggml_tensor * build_layer_ffn(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inpSA,
|
||||
const llama_model & model,
|
||||
const int il) {
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inpSA,
|
||||
const llama_model & model,
|
||||
const int il) {
|
||||
|
||||
// For Granite architectures - scale residual
|
||||
if (hparams.f_residual_scale) {
|
||||
|
||||
@@ -524,6 +524,64 @@ static void test_json_with_dumped_args() {
|
||||
R"({"foo": "bar", "args": {"arg1": [)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":["})"
|
||||
);
|
||||
|
||||
// Unicode tests
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\u0000)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud8)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud80)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\u)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\ud)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})"
|
||||
);
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_positions() {
|
||||
|
||||
@@ -58,7 +58,7 @@ static void test_json_healing() {
|
||||
for (const auto & input : inputs) {
|
||||
common_json out;
|
||||
assert_equals(true, common_json_parse(input, "$foo", out));
|
||||
assert_equals<std::string>(expected, out.json.dump());
|
||||
assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
|
||||
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
|
||||
}
|
||||
};
|
||||
@@ -228,6 +228,56 @@ static void test_json_healing() {
|
||||
R"({"key":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
// Test unicode escape sequences
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(0000$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\u00)",
|
||||
},
|
||||
R"({"a":"\u0000$foo"})",
|
||||
R"(00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud300)",
|
||||
},
|
||||
R"({"a":"\ud300$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(\udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(udc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\u)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"(dc00$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a":"\ud800\udc00)",
|
||||
},
|
||||
R"({"a":"\ud800\udc00$foo"})",
|
||||
R"($foo)"
|
||||
);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
Binary file not shown.
+5
-11
@@ -4226,7 +4226,7 @@ struct server_context {
|
||||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
@@ -5401,15 +5401,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
// TODO: implement
|
||||
//int top_n = 1;
|
||||
//if (body.count("top_n") != 1) {
|
||||
// top_n = body.at("top_n");
|
||||
//} else {
|
||||
// res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
// return;
|
||||
//}
|
||||
|
||||
// if true, use TEI API format, otherwise use Jina API format
|
||||
// Jina: https://jina.ai/reranker/
|
||||
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
|
||||
@@ -5434,6 +5425,8 @@ int main(int argc, char ** argv) {
|
||||
return;
|
||||
}
|
||||
|
||||
int top_n = json_value(body, "top_n", (int)documents.size());
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
@@ -5474,7 +5467,8 @@ int main(int argc, char ** argv) {
|
||||
body,
|
||||
responses,
|
||||
is_tei_format,
|
||||
documents);
|
||||
documents,
|
||||
top_n);
|
||||
|
||||
res_ok(res, root);
|
||||
};
|
||||
|
||||
@@ -102,3 +102,45 @@ def test_rerank_usage(query, doc1, doc2, n_tokens):
|
||||
assert res.status_code == 200
|
||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||
assert res.body['usage']['prompt_tokens'] == n_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_n,expected_len", [
|
||||
(None, len(TEST_DOCUMENTS)), # no top_n parameter
|
||||
(2, 2),
|
||||
(4, 4),
|
||||
(99, len(TEST_DOCUMENTS)), # higher than available docs
|
||||
])
|
||||
def test_rerank_top_n(top_n, expected_len):
|
||||
global server
|
||||
server.start()
|
||||
data = {
|
||||
"query": "Machine learning is",
|
||||
"documents": TEST_DOCUMENTS,
|
||||
}
|
||||
if top_n is not None:
|
||||
data["top_n"] = top_n
|
||||
|
||||
res = server.make_request("POST", "/rerank", data=data)
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["results"]) == expected_len
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_n,expected_len", [
|
||||
(None, len(TEST_DOCUMENTS)), # no top_n parameter
|
||||
(2, 2),
|
||||
(4, 4),
|
||||
(99, len(TEST_DOCUMENTS)), # higher than available docs
|
||||
])
|
||||
def test_rerank_tei_top_n(top_n, expected_len):
|
||||
global server
|
||||
server.start()
|
||||
data = {
|
||||
"query": "Machine learning is",
|
||||
"texts": TEST_DOCUMENTS,
|
||||
}
|
||||
if top_n is not None:
|
||||
data["top_n"] = top_n
|
||||
|
||||
res = server.make_request("POST", "/rerank", data=data)
|
||||
assert res.status_code == 200
|
||||
assert len(res.body) == expected_len
|
||||
|
||||
+35
-38
@@ -849,47 +849,44 @@ static json format_response_rerank(
|
||||
const json & request,
|
||||
const json & ranks,
|
||||
bool is_tei_format,
|
||||
std::vector<std::string> & texts) {
|
||||
json res;
|
||||
if (is_tei_format) {
|
||||
// TEI response format
|
||||
res = json::array();
|
||||
bool return_text = json_value(request, "return_text", false);
|
||||
for (const auto & rank : ranks) {
|
||||
int index = json_value(rank, "index", 0);
|
||||
json elem = json{
|
||||
{"index", index},
|
||||
{"score", json_value(rank, "score", 0.0)},
|
||||
};
|
||||
if (return_text) {
|
||||
elem["text"] = std::move(texts[index]);
|
||||
}
|
||||
res.push_back(elem);
|
||||
}
|
||||
} else {
|
||||
// Jina response format
|
||||
json results = json::array();
|
||||
int32_t n_tokens = 0;
|
||||
for (const auto & rank : ranks) {
|
||||
results.push_back(json{
|
||||
{"index", json_value(rank, "index", 0)},
|
||||
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||
});
|
||||
|
||||
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
||||
}
|
||||
|
||||
res = json{
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", "list"},
|
||||
{"usage", json{
|
||||
{"prompt_tokens", n_tokens},
|
||||
{"total_tokens", n_tokens}
|
||||
}},
|
||||
{"results", results}
|
||||
std::vector<std::string> & texts,
|
||||
int top_n) {
|
||||
int32_t n_tokens = 0;
|
||||
bool return_text = is_tei_format && json_value(request, "return_text", false);
|
||||
std::vector<json> elements; // Temporary vector to hold unsorted elements
|
||||
std::string score_label = is_tei_format ? "score" : "relevance_score";
|
||||
for (const auto & rank : ranks) {
|
||||
int index = json_value(rank, "index", 0);
|
||||
json elem = json{
|
||||
{"index", index},
|
||||
{score_label, json_value(rank, "score", 0.0)},
|
||||
};
|
||||
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
||||
if (return_text) {
|
||||
elem["text"] = std::move(texts[index]);
|
||||
}
|
||||
elements.push_back(elem);
|
||||
}
|
||||
|
||||
std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) {
|
||||
return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0);
|
||||
});
|
||||
|
||||
elements.resize(std::min(top_n, (int)elements.size()));
|
||||
json results = elements;
|
||||
|
||||
if (is_tei_format) return results;
|
||||
|
||||
json res = json{
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", "list"},
|
||||
{"usage", json{
|
||||
{"prompt_tokens", n_tokens},
|
||||
{"total_tokens", n_tokens}
|
||||
}},
|
||||
{"results", results}
|
||||
};
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
import { Check, X } from '@lucide/svelte';
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { ChatAttachmentsList } from '$lib/components/app';
|
||||
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
|
||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -55,6 +56,7 @@
|
||||
|
||||
let isMultiline = $state(false);
|
||||
let messageElement: HTMLElement | undefined = $state();
|
||||
const currentConfig = config();
|
||||
|
||||
$effect(() => {
|
||||
if (!messageElement || !message.content.trim()) return;
|
||||
@@ -123,9 +125,18 @@
|
||||
class="max-w-[80%] rounded-[1.125rem] bg-primary px-3.75 py-1.5 text-primary-foreground data-[multiline]:py-2.5"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
>
|
||||
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
|
||||
{message.content}
|
||||
</span>
|
||||
{#if currentConfig.renderUserContentAsMarkdown}
|
||||
<div bind:this={messageElement} class="text-md">
|
||||
<MarkdownContent
|
||||
class="markdown-user-content text-primary-foreground"
|
||||
content={message.content}
|
||||
/>
|
||||
</div>
|
||||
{:else}
|
||||
<span bind:this={messageElement} class="text-md whitespace-pre-wrap">
|
||||
{message.content}
|
||||
</span>
|
||||
{/if}
|
||||
</Card>
|
||||
{/if}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
ChatMessages,
|
||||
ChatProcessingInfo,
|
||||
EmptyFileAlertDialog,
|
||||
ChatErrorDialog,
|
||||
ServerErrorSplash,
|
||||
ServerInfo,
|
||||
ServerLoadingSplash,
|
||||
@@ -22,10 +23,11 @@
|
||||
activeMessages,
|
||||
activeConversation,
|
||||
deleteConversation,
|
||||
dismissErrorDialog,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
sendMessage,
|
||||
stopGeneration,
|
||||
setMaxContextError
|
||||
stopGeneration
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
supportsVision,
|
||||
@@ -34,7 +36,6 @@
|
||||
serverWarning,
|
||||
serverStore
|
||||
} from '$lib/stores/server.svelte';
|
||||
import { contextService } from '$lib/services';
|
||||
import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra';
|
||||
import { isFileTypeSupported } from '$lib/utils/file-type';
|
||||
import { filterFilesByModalities } from '$lib/utils/modality-file-validation';
|
||||
@@ -79,6 +80,7 @@
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
|
||||
async function handleDeleteConfirm() {
|
||||
@@ -105,6 +107,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleErrorDialogOpenChange(open: boolean) {
|
||||
if (!open) {
|
||||
dismissErrorDialog();
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
}
|
||||
@@ -183,21 +191,6 @@
|
||||
|
||||
const extras = result?.extras;
|
||||
|
||||
// Check context limit using real-time slots data
|
||||
const contextCheck = await contextService.checkContextLimit();
|
||||
|
||||
if (contextCheck && contextCheck.wouldExceed) {
|
||||
const errorMessage = contextService.getContextErrorMessage(contextCheck);
|
||||
|
||||
setMaxContextError({
|
||||
message: errorMessage,
|
||||
estimatedTokens: contextCheck.currentUsage,
|
||||
maxContext: contextCheck.maxContext
|
||||
});
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enable autoscroll for user-initiated message sending
|
||||
userScrolledUp = false;
|
||||
autoScrollEnabled = true;
|
||||
@@ -461,6 +454,13 @@
|
||||
}}
|
||||
/>
|
||||
|
||||
<ChatErrorDialog
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? 'server'}
|
||||
/>
|
||||
|
||||
<style>
|
||||
.conversation-chat-form {
|
||||
position: relative;
|
||||
|
||||
@@ -80,6 +80,11 @@
|
||||
key: 'showModelInfo',
|
||||
label: 'Show model information',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
label: 'Render user content as Markdown',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
<script lang="ts">
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { AlertTriangle, TimerOff } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
type: 'timeout' | 'server';
|
||||
message: string;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
let { open = $bindable(), type, message, onOpenChange }: Props = $props();
|
||||
|
||||
const isTimeout = $derived(type === 'timeout');
|
||||
const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
|
||||
const description = $derived(
|
||||
isTimeout
|
||||
? 'The request did not receive a response from the server before timing out.'
|
||||
: 'The server responded with an error message. Review the details below.'
|
||||
);
|
||||
const iconClass = $derived(isTimeout ? 'text-destructive' : 'text-amber-500');
|
||||
const badgeClass = $derived(
|
||||
isTimeout
|
||||
? 'border-destructive/40 bg-destructive/10 text-destructive'
|
||||
: 'border-amber-500/40 bg-amber-500/10 text-amber-600 dark:text-amber-400'
|
||||
);
|
||||
|
||||
function handleOpenChange(newOpen: boolean) {
|
||||
open = newOpen;
|
||||
onOpenChange?.(newOpen);
|
||||
}
|
||||
</script>
|
||||
|
||||
<AlertDialog.Root {open} onOpenChange={handleOpenChange}>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title class="flex items-center gap-2">
|
||||
{#if isTimeout}
|
||||
<TimerOff class={`h-5 w-5 ${iconClass}`} />
|
||||
{:else}
|
||||
<AlertTriangle class={`h-5 w-5 ${iconClass}`} />
|
||||
{/if}
|
||||
|
||||
{title}
|
||||
</AlertDialog.Title>
|
||||
|
||||
<AlertDialog.Description>
|
||||
{description}
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
<div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
|
||||
<p class="font-medium">{message}</p>
|
||||
</div>
|
||||
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Action onclick={() => handleOpenChange(false)}>Close</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
@@ -1,66 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { AlertTriangle } from '@lucide/svelte';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { maxContextError, clearMaxContextError } from '$lib/stores/chat.svelte';
|
||||
</script>
|
||||
|
||||
<AlertDialog.Root
|
||||
open={maxContextError() !== null}
|
||||
onOpenChange={(open) => !open && clearMaxContextError()}
|
||||
>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title class="flex items-center gap-2">
|
||||
<AlertTriangle class="h-5 w-5 text-destructive" />
|
||||
|
||||
Message Too Long
|
||||
</AlertDialog.Title>
|
||||
|
||||
<AlertDialog.Description>
|
||||
Your message exceeds the model's context window and cannot be processed.
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
{#if maxContextError()}
|
||||
<div class="space-y-3 text-sm">
|
||||
<div class="rounded-lg bg-muted p-3">
|
||||
<div class="mb-2 font-medium">Token Usage:</div>
|
||||
|
||||
<div class="space-y-1 text-muted-foreground">
|
||||
<div>
|
||||
Estimated tokens:
|
||||
|
||||
<span class="font-mono">
|
||||
{maxContextError()?.estimatedTokens.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
Context window:
|
||||
|
||||
<span class="font-mono">
|
||||
{maxContextError()?.maxContext.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class="mb-2 font-medium">Suggestions:</div>
|
||||
|
||||
<ul class="list-inside list-disc space-y-1 text-muted-foreground">
|
||||
<li>Shorten your message</li>
|
||||
|
||||
<li>Remove some file attachments</li>
|
||||
|
||||
<li>Start a new conversation</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Action onclick={() => clearMaxContextError()}>Got it</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
@@ -30,12 +30,11 @@ export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
|
||||
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
|
||||
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
|
||||
|
||||
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
|
||||
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
|
||||
|
||||
export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte';
|
||||
|
||||
export { default as MaximumContextAlertDialog } from './dialogs/MaximumContextAlertDialog.svelte';
|
||||
|
||||
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
|
||||
|
||||
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
|
||||
|
||||
@@ -12,6 +12,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||
pasteLongTextToFileLen: 2500,
|
||||
pdfAsImage: false,
|
||||
showModelInfo: false,
|
||||
renderUserContentAsMarkdown: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'top_k;typ_p;top_p;min_p;temperature',
|
||||
temperature: 0.8,
|
||||
@@ -84,6 +85,7 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
'Ask for confirmation before automatically changing conversation title when editing the first message.',
|
||||
pdfAsImage: 'Parse PDF as image instead of text (requires vision-capable model).',
|
||||
showModelInfo: 'Display the model name used to generate each message below the message content.',
|
||||
renderUserContentAsMarkdown: 'Render user messages using markdown formatting in the chat.',
|
||||
pyInterpreterEnabled:
|
||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
|
||||
};
|
||||
|
||||
@@ -13,7 +13,7 @@ import { slotsService } from './slots';
|
||||
* - Manages streaming and non-streaming response parsing
|
||||
* - Provides request abortion capabilities
|
||||
* - Converts database messages to API format
|
||||
* - Handles error translation and context detection
|
||||
* - Handles error translation for server responses
|
||||
*
|
||||
* - **ChatStore**: Stateful orchestration and UI state management
|
||||
* - Uses ChatService for all AI model communication
|
||||
@@ -26,7 +26,6 @@ import { slotsService } from './slots';
|
||||
* - Streaming response handling with real-time callbacks
|
||||
* - Reasoning content extraction and processing
|
||||
* - File attachment processing (images, PDFs, audio, text)
|
||||
* - Context error detection and reporting
|
||||
* - Request lifecycle management (abort, cleanup)
|
||||
*/
|
||||
export class ChatService {
|
||||
@@ -209,10 +208,13 @@ export class ChatService {
|
||||
userFriendlyError = new Error(
|
||||
'Unable to connect to server - please check if the server is running'
|
||||
);
|
||||
userFriendlyError.name = 'NetworkError';
|
||||
} else if (error.message.includes('ECONNREFUSED')) {
|
||||
userFriendlyError = new Error('Connection refused - server may be offline');
|
||||
userFriendlyError.name = 'NetworkError';
|
||||
} else if (error.message.includes('ETIMEDOUT')) {
|
||||
userFriendlyError = new Error('Request timeout - server may be overloaded');
|
||||
userFriendlyError = new Error('Request timed out - the server took too long to respond');
|
||||
userFriendlyError.name = 'TimeoutError';
|
||||
} else {
|
||||
userFriendlyError = error;
|
||||
}
|
||||
@@ -262,6 +264,7 @@ export class ChatService {
|
||||
let fullReasoningContent = '';
|
||||
let hasReceivedData = false;
|
||||
let lastTimings: ChatMessageTimings | undefined;
|
||||
let streamFinished = false;
|
||||
|
||||
try {
|
||||
let chunk = '';
|
||||
@@ -277,18 +280,8 @@ export class ChatService {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6);
|
||||
if (data === '[DONE]') {
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
return;
|
||||
}
|
||||
|
||||
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
|
||||
|
||||
return;
|
||||
streamFinished = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -326,13 +319,13 @@ export class ChatService {
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
return;
|
||||
if (streamFinished) {
|
||||
if (!hasReceivedData && aggregatedContent.length === 0) {
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
|
||||
}
|
||||
} catch (error) {
|
||||
const err = error instanceof Error ? error : new Error('Stream error');
|
||||
@@ -368,12 +361,8 @@ export class ChatService {
|
||||
const responseText = await response.text();
|
||||
|
||||
if (!responseText.trim()) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
throw contextError;
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
const data: ApiChatCompletionResponse = JSON.parse(responseText);
|
||||
@@ -385,22 +374,14 @@ export class ChatService {
|
||||
}
|
||||
|
||||
if (!content.trim()) {
|
||||
const contextError = new Error(
|
||||
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
|
||||
);
|
||||
contextError.name = 'ContextError';
|
||||
onError?.(contextError);
|
||||
throw contextError;
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
onComplete?.(content, reasoningContent);
|
||||
|
||||
return content;
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'ContextError') {
|
||||
throw error;
|
||||
}
|
||||
|
||||
const err = error instanceof Error ? error : new Error('Parse error');
|
||||
|
||||
onError?.(err);
|
||||
@@ -594,37 +575,19 @@ export class ChatService {
|
||||
const errorText = await response.text();
|
||||
const errorData: ApiErrorResponse = JSON.parse(errorText);
|
||||
|
||||
if (errorData.error?.type === 'exceed_context_size_error') {
|
||||
const contextError = errorData.error as ApiContextSizeError;
|
||||
const error = new Error(contextError.message);
|
||||
error.name = 'ContextError';
|
||||
// Attach structured context information
|
||||
(
|
||||
error as Error & {
|
||||
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
|
||||
}
|
||||
).contextInfo = {
|
||||
promptTokens: contextError.n_prompt_tokens,
|
||||
maxContext: contextError.n_ctx,
|
||||
estimatedTokens: contextError.n_prompt_tokens
|
||||
};
|
||||
return error;
|
||||
}
|
||||
|
||||
// Fallback for other error types
|
||||
const message = errorData.error?.message || 'Unknown server error';
|
||||
return new Error(message);
|
||||
const error = new Error(message);
|
||||
error.name = response.status === 400 ? 'ServerError' : 'HttpError';
|
||||
|
||||
return error;
|
||||
} catch {
|
||||
// If we can't parse the error response, return a generic error
|
||||
return new Error(`Server error (${response.status}): ${response.statusText}`);
|
||||
const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
|
||||
fallback.name = 'HttpError';
|
||||
return fallback;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the processing state with timing information from the server response
|
||||
* @param timings - Timing data from the API response
|
||||
* @param promptProgress - Progress data from the API response
|
||||
*/
|
||||
private updateProcessingState(
|
||||
timings?: ChatMessageTimings,
|
||||
promptProgress?: ChatMessagePromptProgress
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
import { slotsService } from './slots';
|
||||
|
||||
export interface ContextCheckResult {
|
||||
wouldExceed: boolean;
|
||||
currentUsage: number;
|
||||
maxContext: number;
|
||||
availableTokens: number;
|
||||
reservedTokens: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* ContextService - Context window management and limit checking
|
||||
*
|
||||
* This service provides context window monitoring and limit checking using real-time
|
||||
* server data from the slots service. It helps prevent context overflow by tracking
|
||||
* current usage and calculating available space for new content.
|
||||
*
|
||||
* **Architecture & Relationships:**
|
||||
* - **ContextService** (this class): Context limit monitoring
|
||||
* - Uses SlotsService for real-time context usage data
|
||||
* - Calculates available tokens with configurable reserves
|
||||
* - Provides context limit checking and error messaging
|
||||
* - Helps prevent context window overflow
|
||||
*
|
||||
* - **SlotsService**: Provides current context usage from server slots
|
||||
* - **ChatStore**: Uses context checking before sending messages
|
||||
* - **UI Components**: Display context usage warnings and limits
|
||||
*
|
||||
* **Key Features:**
|
||||
* - **Real-time Context Checking**: Uses live server data for accuracy
|
||||
* - **Token Reservation**: Reserves tokens for response generation
|
||||
* - **Limit Detection**: Prevents context window overflow
|
||||
* - **Usage Reporting**: Detailed context usage statistics
|
||||
* - **Error Messaging**: User-friendly context limit messages
|
||||
* - **Configurable Reserves**: Adjustable token reservation for responses
|
||||
*
|
||||
* **Context Management:**
|
||||
* - Monitors current context usage from active slots
|
||||
* - Calculates available space considering reserved tokens
|
||||
* - Provides early warning before context limits are reached
|
||||
* - Helps optimize conversation length and content
|
||||
*/
|
||||
export class ContextService {
|
||||
private reserveTokens: number;
|
||||
|
||||
constructor(reserveTokens = 512) {
|
||||
this.reserveTokens = reserveTokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the context limit would be exceeded
|
||||
*
|
||||
* @returns {Promise<ContextCheckResult | null>} Promise that resolves to the context check result or null if an error occurs
|
||||
*/
|
||||
async checkContextLimit(): Promise<ContextCheckResult | null> {
|
||||
try {
|
||||
const currentState = await slotsService.getCurrentState();
|
||||
|
||||
if (!currentState) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const maxContext = currentState.contextTotal;
|
||||
const currentUsage = currentState.contextUsed;
|
||||
const availableTokens = maxContext - currentUsage - this.reserveTokens;
|
||||
const wouldExceed = availableTokens <= 0;
|
||||
|
||||
return {
|
||||
wouldExceed,
|
||||
currentUsage,
|
||||
maxContext,
|
||||
availableTokens: Math.max(0, availableTokens),
|
||||
reservedTokens: this.reserveTokens
|
||||
};
|
||||
} catch (error) {
|
||||
console.warn('Error checking context limit:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a formatted error message for context limit exceeded
|
||||
*
|
||||
* @param {ContextCheckResult} result - Context check result
|
||||
* @returns {string} Formatted error message
|
||||
*/
|
||||
getContextErrorMessage(result: ContextCheckResult): string {
|
||||
const usagePercent = Math.round((result.currentUsage / result.maxContext) * 100);
|
||||
return `Context window is nearly full. Current usage: ${result.currentUsage.toLocaleString()}/${result.maxContext.toLocaleString()} tokens (${usagePercent}%). Available space: ${result.availableTokens.toLocaleString()} tokens (${result.reservedTokens} reserved for response).`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of tokens to reserve for response generation
|
||||
*
|
||||
* @param {number} tokens - Number of tokens to reserve
|
||||
*/
|
||||
setReserveTokens(tokens: number): void {
|
||||
this.reserveTokens = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
export const contextService = new ContextService();
|
||||
@@ -1,3 +1,2 @@
|
||||
export { chatService } from './chat';
|
||||
export { contextService } from './context';
|
||||
export { slotsService } from './slots';
|
||||
|
||||
@@ -39,7 +39,6 @@ import type { ExportedConversations } from '$lib/types/database';
|
||||
* - Conversation branching for exploring different response paths
|
||||
* - Streaming AI responses with real-time content updates
|
||||
* - File attachment support (images, PDFs, text files, audio)
|
||||
* - Context window management with error recovery
|
||||
* - Partial response saving when generation is interrupted
|
||||
* - Message editing with automatic response regeneration
|
||||
*/
|
||||
@@ -48,11 +47,9 @@ class ChatStore {
|
||||
activeMessages = $state<DatabaseMessage[]>([]);
|
||||
conversations = $state<DatabaseConversation[]>([]);
|
||||
currentResponse = $state('');
|
||||
errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
|
||||
isInitialized = $state(false);
|
||||
isLoading = $state(false);
|
||||
maxContextError = $state<{ message: string; estimatedTokens: number; maxContext: number } | null>(
|
||||
null
|
||||
);
|
||||
titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>;
|
||||
|
||||
constructor() {
|
||||
@@ -69,8 +66,6 @@ class ChatStore {
|
||||
try {
|
||||
await this.loadConversations();
|
||||
|
||||
this.maxContextError = null;
|
||||
|
||||
this.isInitialized = true;
|
||||
} catch (error) {
|
||||
console.error('Failed to initialize chat store:', error);
|
||||
@@ -99,8 +94,6 @@ class ChatStore {
|
||||
this.activeConversation = conversation;
|
||||
this.activeMessages = [];
|
||||
|
||||
this.maxContextError = null;
|
||||
|
||||
await goto(`#/chat/${conversation.id}`);
|
||||
|
||||
return conversation.id;
|
||||
@@ -133,8 +126,6 @@ class ChatStore {
|
||||
this.activeMessages = await DatabaseStore.getConversationMessages(convId);
|
||||
}
|
||||
|
||||
this.maxContextError = null;
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error('Failed to load conversation:', error);
|
||||
@@ -418,56 +409,6 @@ class ChatStore {
|
||||
return;
|
||||
}
|
||||
|
||||
if (error.name === 'ContextError') {
|
||||
console.warn('Context error detected:', error.message);
|
||||
this.isLoading = false;
|
||||
this.currentResponse = '';
|
||||
|
||||
const messageIndex = this.activeMessages.findIndex(
|
||||
(m: DatabaseMessage) => m.id === assistantMessage.id
|
||||
);
|
||||
|
||||
if (messageIndex !== -1) {
|
||||
this.activeMessages.splice(messageIndex, 1);
|
||||
DatabaseStore.deleteMessage(assistantMessage.id).catch(console.error);
|
||||
}
|
||||
|
||||
// Use structured context info from new exceed_context_size_error format if available
|
||||
const contextInfo = (
|
||||
error as Error & {
|
||||
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
|
||||
}
|
||||
).contextInfo;
|
||||
let estimatedTokens = 0;
|
||||
let maxContext = serverStore.serverProps?.default_generation_settings.n_ctx || 8192;
|
||||
|
||||
if (contextInfo) {
|
||||
// Use precise token counts from server response
|
||||
estimatedTokens = contextInfo.promptTokens;
|
||||
maxContext = contextInfo.maxContext;
|
||||
} else {
|
||||
// Fallback to estimation for older error format
|
||||
try {
|
||||
// Rough estimation: ~4 characters per token
|
||||
const messageContent = JSON.stringify(messages);
|
||||
estimatedTokens = Math.ceil(messageContent.length / 4);
|
||||
} catch {
|
||||
estimatedTokens = 0;
|
||||
}
|
||||
}
|
||||
|
||||
this.maxContextError = {
|
||||
message: error.message,
|
||||
estimatedTokens,
|
||||
maxContext
|
||||
};
|
||||
|
||||
if (onError) {
|
||||
onError(error);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
console.error('Streaming error:', error);
|
||||
this.isLoading = false;
|
||||
this.currentResponse = '';
|
||||
@@ -477,9 +418,19 @@ class ChatStore {
|
||||
);
|
||||
|
||||
if (messageIndex !== -1) {
|
||||
this.activeMessages[messageIndex].content = `Error: ${error.message}`;
|
||||
const [failedMessage] = this.activeMessages.splice(messageIndex, 1);
|
||||
|
||||
if (failedMessage) {
|
||||
DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => {
|
||||
console.error('Failed to remove assistant message after error:', cleanupError);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
|
||||
|
||||
this.showErrorDialog(dialogType, error.message);
|
||||
|
||||
if (onError) {
|
||||
onError(error);
|
||||
}
|
||||
@@ -487,6 +438,14 @@ class ChatStore {
|
||||
});
|
||||
}
|
||||
|
||||
private showErrorDialog(type: 'timeout' | 'server', message: string): void {
|
||||
this.errorDialogState = { type, message };
|
||||
}
|
||||
|
||||
dismissErrorDialog(): void {
|
||||
this.errorDialogState = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if an error is an abort error (user cancelled operation)
|
||||
* @param error - The error to check
|
||||
@@ -574,6 +533,7 @@ class ChatStore {
|
||||
return;
|
||||
}
|
||||
|
||||
this.errorDialogState = null;
|
||||
this.isLoading = true;
|
||||
this.currentResponse = '';
|
||||
|
||||
@@ -603,37 +563,23 @@ class ChatStore {
|
||||
|
||||
const conversationContext = this.activeMessages.slice(0, -1);
|
||||
|
||||
await this.streamChatCompletion(
|
||||
conversationContext,
|
||||
assistantMessage,
|
||||
undefined,
|
||||
(error: Error) => {
|
||||
if (error.name === 'ContextError' && userMessage) {
|
||||
const userMessageIndex = this.findMessageIndex(userMessage.id);
|
||||
|
||||
if (userMessageIndex !== -1) {
|
||||
this.activeMessages.splice(userMessageIndex, 1);
|
||||
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
await this.streamChatCompletion(conversationContext, assistantMessage);
|
||||
} catch (error) {
|
||||
if (this.isAbortError(error)) {
|
||||
this.isLoading = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (error instanceof Error && error.name === 'ContextError' && userMessage) {
|
||||
const userMessageIndex = this.findMessageIndex(userMessage.id);
|
||||
if (userMessageIndex !== -1) {
|
||||
this.activeMessages.splice(userMessageIndex, 1);
|
||||
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
|
||||
}
|
||||
}
|
||||
|
||||
console.error('Failed to send message:', error);
|
||||
this.isLoading = false;
|
||||
if (!this.errorDialogState) {
|
||||
if (error instanceof Error) {
|
||||
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
|
||||
this.showErrorDialog(dialogType, error.message);
|
||||
} else {
|
||||
this.showErrorDialog('server', 'Unknown error occurred while sending message');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -662,24 +608,6 @@ class ChatStore {
|
||||
this.currentResponse = '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the max context error state
|
||||
* Removes any displayed context limit warnings
|
||||
*/
|
||||
clearMaxContextError(): void {
|
||||
this.maxContextError = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the max context error state
|
||||
* @param error - The context error details or null to clear
|
||||
*/
|
||||
setMaxContextError(
|
||||
error: { message: string; estimatedTokens: number; maxContext: number } | null
|
||||
): void {
|
||||
this.maxContextError = error;
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves partial response if generation was interrupted
|
||||
* Preserves user's partial content and timing data when generation is stopped early
|
||||
@@ -1250,7 +1178,6 @@ class ChatStore {
|
||||
this.activeMessages = [];
|
||||
this.currentResponse = '';
|
||||
this.isLoading = false;
|
||||
this.maxContextError = null;
|
||||
}
|
||||
|
||||
/** Refreshes active messages based on currNode after branch navigation */
|
||||
@@ -1538,6 +1465,7 @@ class ChatStore {
|
||||
private async generateResponseForMessage(userMessageId: string): Promise<void> {
|
||||
if (!this.activeConversation) return;
|
||||
|
||||
this.errorDialogState = null;
|
||||
this.isLoading = true;
|
||||
this.currentResponse = '';
|
||||
|
||||
@@ -1584,7 +1512,7 @@ export const activeMessages = () => chatStore.activeMessages;
|
||||
export const isLoading = () => chatStore.isLoading;
|
||||
export const currentResponse = () => chatStore.currentResponse;
|
||||
export const isInitialized = () => chatStore.isInitialized;
|
||||
export const maxContextError = () => chatStore.maxContextError;
|
||||
export const errorDialog = () => chatStore.errorDialogState;
|
||||
|
||||
export const createConversation = chatStore.createConversation.bind(chatStore);
|
||||
export const downloadConversation = chatStore.downloadConversation.bind(chatStore);
|
||||
@@ -1592,9 +1520,9 @@ export const exportAllConversations = chatStore.exportAllConversations.bind(chat
|
||||
export const importConversations = chatStore.importConversations.bind(chatStore);
|
||||
export const deleteConversation = chatStore.deleteConversation.bind(chatStore);
|
||||
export const sendMessage = chatStore.sendMessage.bind(chatStore);
|
||||
export const dismissErrorDialog = chatStore.dismissErrorDialog.bind(chatStore);
|
||||
|
||||
export const gracefulStop = chatStore.gracefulStop.bind(chatStore);
|
||||
export const clearMaxContextError = chatStore.clearMaxContextError.bind(chatStore);
|
||||
export const setMaxContextError = chatStore.setMaxContextError.bind(chatStore);
|
||||
|
||||
// Branching operations
|
||||
export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore);
|
||||
|
||||
@@ -197,7 +197,7 @@ class ServerStore {
|
||||
errorMessage = 'Server not found - check server address';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('ETIMEDOUT')) {
|
||||
errorMessage = 'Connection timeout - server may be overloaded';
|
||||
errorMessage = 'Request timed out - the server took too long to respond';
|
||||
isOfflineLikeError = true;
|
||||
} else if (error.message.includes('503')) {
|
||||
errorMessage = 'Server temporarily unavailable - try again shortly';
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
<script lang="ts">
|
||||
import '../app.css';
|
||||
import { page } from '$app/state';
|
||||
import {
|
||||
ChatSidebar,
|
||||
ConversationTitleUpdateDialog,
|
||||
MaximumContextAlertDialog
|
||||
} from '$lib/components/app';
|
||||
import { ChatSidebar, ConversationTitleUpdateDialog } from '$lib/components/app';
|
||||
import {
|
||||
activeMessages,
|
||||
isLoading,
|
||||
@@ -145,8 +141,6 @@
|
||||
|
||||
<Toaster richColors />
|
||||
|
||||
<MaximumContextAlertDialog />
|
||||
|
||||
<ConversationTitleUpdateDialog
|
||||
bind:open={titleUpdateDialogOpen}
|
||||
currentTitle={titleUpdateCurrentTitle}
|
||||
|
||||
Reference in New Issue
Block a user