Compare commits

16 Commits

Author SHA1 Message Date
Johannes Gäßler 7049736b2d CUDA: fix numerical issues in tile FA kernel (#16540) 2025-10-13 17:29:45 +03:00
Jie Fu (傅杰) 01d2bdc2bc ggml : fix build broken with -march=armv9-a on MacOS (#16520)
* ggml : fix build broken with -march=armv9-a on MacOS

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Add #pragma message

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Address review comment.

Signed-off-by: Jie Fu <jiefu@tencent.com>

* Update ggml/src/ggml-cpu/ggml-cpu.c

---------

Signed-off-by: Jie Fu <jiefu@tencent.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-10-13 15:48:47 +03:00
Chenguang Li 56fc38b965 CANN: fix CPU memory leak in CANN backend (#16549)
This commit fixes a CPU-side memory leak issue in the CANN backend,
which occurred when intermediate aclTensorList objects were not properly
released after operator execution. The leak happened during repeated
invocations of CANN ops (e.g., FlashAttention), leading to increasing
host memory usage over time.

Proper resource cleanup (aclDestroyTensorList and related release logic)
has been added to ensure that all temporary tensors are correctly freed.
2025-10-13 17:01:24 +08:00
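The leak described above comes from temporaries that are created for each operator call and never destroyed. A minimal sketch of one way to make such cleanup automatic is shown below. `FakeTensorList`, `create_tensor_list` and `destroy_tensor_list` are hypothetical stand-ins, not the real ACL/CANN API, and the actual fix may simply add explicit release calls rather than use RAII.

```cpp
#include <memory>

// Hypothetical stand-ins for an opaque handle and its create/destroy pair.
// These are NOT the real ACL/CANN API; they only model the ownership rule.
struct FakeTensorList { int n_tensors; };

static int g_live_lists = 0; // number of handles created but not yet destroyed

FakeTensorList * create_tensor_list(int n) {
    ++g_live_lists;
    return new FakeTensorList{n};
}

void destroy_tensor_list(FakeTensorList * list) {
    --g_live_lists;
    delete list;
}

// Deleter so that std::unique_ptr releases the handle on scope exit.
struct TensorListDeleter {
    void operator()(FakeTensorList * list) const { destroy_tensor_list(list); }
};
using TensorListPtr = std::unique_ptr<FakeTensorList, TensorListDeleter>;

// Models one operator invocation that needs a temporary list.
int run_op_once(int n) {
    TensorListPtr tmp(create_tensor_list(n));
    return tmp->n_tensors * 2; // placeholder for the real computation
} // tmp is destroyed here, including on early return or exception

// Runs the op repeatedly and reports how many handles are still alive.
int leaked_after(int iterations) {
    for (int i = 0; i < iterations; ++i) {
        run_op_once(4);
    }
    return g_live_lists;
}
```

With this ownership rule, repeated invocations leave the live-handle count at zero, which is the property the commit restores for host memory.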
Pascal 1fb9504eb7 fix: add remark plugin to render raw HTML as literal text (#16505)
* fix: add remark plugin to render raw HTML as literal text

Implemented a missing MDAST stage to neutralize raw HTML, as major LLM WebUIs
do, ensuring consistent and safe Markdown rendering

Introduced 'remarkLiteralHtml', a plugin that converts raw HTML nodes in the
Markdown AST into plain-text equivalents while preserving indentation and
line breaks. This ensures consistent rendering and prevents unintended HTML
execution, without altering valid Markdown structure

Kept 'remarkRehype' in the pipeline since it performs the required conversion
from MDAST to HAST for KaTeX, syntax highlighting, and HTML serialization

Refined the link-enhancement logic to skip unnecessary DOM rewrites,
fixing a subtle bug where extra paragraphs were injected after the first
line due to full innerHTML reconstruction, and ensuring links open in new
tabs only when required

Final pipeline: remarkGfm -> remarkMath -> remarkBreaks -> remarkLiteralHtml
-> remarkRehype -> rehypeKatex -> rehypeHighlight -> rehypeStringify

* fix: address review feedback from allozaur

* chore: update webui build output
2025-10-13 10:55:32 +02:00
Sam/Samuel 3f750f8d76 metal: add support for opt_step_sgd (#16539)
* metal: add support for opt_step_sgd

* add newline to pass EditorConfig check
2025-10-13 11:25:02 +03:00
Georgi Gerganov c515fc5771 ggml : fix scalar path for computing norm (#16558) 2025-10-13 11:22:27 +03:00
hipudding f9bc66c3eb CANN: Update several operators to support FP16 data format (#16251)
Many Ascend operators internally use FP16 precision for computation.
If input data is in FP32, it must first be cast to FP16 before
computation, and then cast back to FP32 after computation, which
introduces unnecessary cast operations. Moreover, FP16 computation
requires significantly less workload compared to FP32, leading to
noticeable efficiency improvements.

In this change, `get_rows`, `rms_norm`, and `flash_attn_ext` are extended
to support multiple data types. Validation on the Qwen2 0.5b model shows
correct accuracy and about 10% performance gain in concurrent scenarios.

Co-authored-by: noemotiovon <757486878@qq.com>
2025-10-13 08:52:22 +08:00
Sam/Samuel a31cf36ad9 metal : add opt_step_adamw and op_sum (#16529)
* scaffold to support opt step adamw on metal (not written so far)

* add opt-step-adamw kernel for metal

* pass op->src[4] as a separate buffer to the pipeline

* add bounds check to opt-step-adamw kernel

* complete scaffold for GGML_OP_SUM

* naive GGML_OP_SUM kernel

* remove unwanted comment

* change OP_SUM capability gate

* Add has_simdgroup_reduction to both ops to pass CI
2025-10-12 21:43:14 +03:00
Pascal 81d54bbfd5 webui: remove client-side context pre-check and rely on backend for limits (#16506)
* fix: make SSE client robust to premature [DONE] in agentic proxy chains

* webui: remove client-side context pre-check and rely on backend for limits

Removed the client-side context window pre-check and now simply sends messages
while keeping the dialog imports limited to core components, eliminating the
maximum context alert path

Simplified streaming and non-streaming chat error handling to surface a generic
'No response received from server' error whenever the backend returns no content

Removed the obsolete maxContextError plumbing from the chat store so state
management now focuses on the core message flow without special context-limit cases

* webui: cosmetic rename of error messages

* Update tools/server/webui/src/lib/stores/chat.svelte.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/stores/chat.svelte.ts

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* Update tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>

* chore: update webui build output

---------

Co-authored-by: Aleksander Grygier <aleksander.grygier@gmail.com>
2025-10-12 18:06:41 +02:00
Neo Zhang Jianyu c7be9febcb [SYCL] fix UT fault cases: count-equal, argsort, pad OPs (#16521)
* fix/refactor OP argsort, pad

* fix count-equal op

* update SYCL OP list

* fix format issue

---------

Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
2025-10-12 21:53:35 +08:00
Mathieu Baudier 8415f61e23 ci : add Vulkan on Ubuntu with default packages build (#16532)
* ci: build Vulkan on Ubuntu with default packages

* ci: disable tests in Vulkan build with default Ubuntu packages
2025-10-12 15:48:03 +02:00
Aldehir Rojas 2c301e91ab common : handle unicode during partial json parsing (#16526)
* common : handle unicode during partial json parsing

* common : set missing `ensure_ascii = true` during json dump
2025-10-12 16:18:47 +03:00
Georgi Gerganov 4b2dae383d common : update presets (#16504)
* presets : add --embd-gemma-default and remove old embedding presets

* presets : add gpt-oss presets

* presets : add vision presets

* cont : remove reasoning overrides [no ci]

* cont : fix batch size for embedding gemma [no ci]
2025-10-12 09:29:13 +03:00
sirus20x6 41aac5c69b ggml : Fix FP16 ELU positive branch (#16519)
Co-authored-by: Aaron <shelhamer.aaron@gmail.com>
2025-10-12 08:25:37 +03:00
Daniel Bevenius a2fba89a42 hparams : add check for layer index in is_recurrent (#16511)
* hparams : add check for layer index in is_recurrent

This commit adds a check in the is_recurrent method to ensure that the
provided layer index is within the valid range.

The motivation for this change is to prevent potential out-of-bounds
accesses and to be consistent with other methods in the class that perform
similar checks, like is_swa.
2025-10-12 07:19:06 +02:00
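The layer-index check described above can be sketched as follows. `HParamsSketch`, `MAX_LAYERS` and the field names are illustrative assumptions, not the real llama.cpp struct, and this sketch throws an exception where the real code may abort instead.

```cpp
#include <array>
#include <cstdint>
#include <stdexcept>

// Hypothetical, simplified stand-in for a hyperparameter struct; the field
// names and the layer limit are illustrative, not the real llama.cpp layout.
constexpr uint32_t MAX_LAYERS = 8;

struct HParamsSketch {
    uint32_t n_layer = 4;
    std::array<bool, MAX_LAYERS> recurrent_layer_arr{};

    // Reject indices outside [0, n_layer) before touching the array.
    bool is_recurrent(uint32_t il) const {
        if (il < n_layer) {
            return recurrent_layer_arr[il];
        }
        throw std::out_of_range("il is out of range");
    }
};

// Helper for callers: true if the index is rejected by the range check.
bool index_rejected(const HParamsSketch & hp, uint32_t il) {
    try {
        (void) hp.is_recurrent(il);
        return false;
    } catch (const std::out_of_range &) {
        return true;
    }
}
```

Note that the check compares against `n_layer`, not the array capacity: an index can be inside the backing array and still refer to a layer the model does not have.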
sirus20x6 20cc625edc ggml: Correct SVE implementation in ggml_vec_dot_f16_unroll (#16518)
The previous SVE implementation for `ggml_vec_dot_f16_unroll` contained a bug due to a copy-paste error. The wrong variable was used in an FMA instruction, leading to incorrect results. This commit corrects the variable usage and improves the clarity of the code by renaming variables to avoid confusion.

Co-authored-by: Aaron <shelhamer.aaron@gmail.com>
2025-10-12 08:15:00 +03:00
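The commit message does not say which variable was wrong, so the following is only a scalar sketch of the general shape of an unrolled dot product and of where a copy-paste slip can hide. The function name and the two-row layout are assumptions for illustration; the real routine uses SVE intrinsics and FP16 data.

```cpp
#include <cstddef>
#include <vector>

// Result of dotting two rows against the same vector in a single pass.
struct Dot2 {
    float s0;
    float s1;
};

// Scalar analogue of an "unrolled" vec-dot: each element of y is loaded once
// and reused for both rows. Assumes x0 and x1 are at least as long as y.
Dot2 vec_dot_unroll2(const std::vector<float> & x0,
                     const std::vector<float> & x1,
                     const std::vector<float> & y) {
    float sum0 = 0.0f;
    float sum1 = 0.0f;
    for (std::size_t i = 0; i < y.size(); ++i) {
        const float ay = y[i];
        sum0 += x0[i] * ay;
        sum1 += x1[i] * ay; // a copy-paste slip would read x0[i] here
    }
    return {sum0, sum1};
}
```

Giving each row its own clearly named accumulator and operand, as the commit does by renaming variables, makes this kind of mix-up easier to spot in review.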
50 changed files with 13499 additions and 5042 deletions
+33
@@ -387,6 +387,39 @@ jobs:
cd build
ctest -L main --verbose
ubuntu-24-cmake-vulkan-deb:
runs-on: ubuntu-24.04
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-24-cmake-vulkan-deb
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
- name: Configure
id: cmake_configure
run: |
cmake -B build \
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
-DGGML_BACKEND_DL=ON \
-DGGML_CPU_ALL_VARIANTS=ON \
-DGGML_VULKAN=ON
- name: Build
id: cmake_build
run: |
cmake --build build -j $(nproc)
ubuntu-24-cmake-vulkan:
runs-on: ubuntu-24.04
+161 -133
@@ -3358,7 +3358,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg(
{"--chat-template-kwargs"}, "STRING",
string_format("sets additional params for the json template parser"),
[](common_params & params, const std::string & value) {
[](common_params & params, const std::string & value) {
auto parsed = json::parse(value);
for (const auto & item : parsed.items()) {
params.default_template_kwargs[item.key()] = item.value().dump();
@@ -3570,21 +3570,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_file(common_log_main(), value.c_str());
}
));
add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
"'auto' enables colors when output is to a terminal",
[](common_params &, const std::string & value) {
if (is_truthy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
} else if (is_falsey(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
} else if (is_autoy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
} else {
throw std::invalid_argument(
string_format("error: unknown value for --log-colors: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_LOG_COLORS"));
add_opt(common_arg(
{"--log-colors"}, "[on|off|auto]",
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
"'auto' enables colors when output is to a terminal",
[](common_params &, const std::string & value) {
if (is_truthy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
} else if (is_falsey(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
} else if (is_autoy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
} else {
throw std::invalid_argument(
string_format("error: unknown value for --log-colors: '%s'\n", value.c_str()));
}
}
).set_env("LLAMA_LOG_COLORS"));
add_opt(common_arg(
{"-v", "--verbose", "--log-verbose"},
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
@@ -3850,7 +3852,87 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_TTS}));
// model-specific
add_opt(common_arg(
{"--diffusion-steps"}, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
[](common_params & params, int value) { params.diffusion.steps = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-visual"},
string_format("enable visual diffusion mode (show progressive generation) (default: %s)", params.diffusion.visual_mode ? "true" : "false"),
[](common_params & params) { params.diffusion.visual_mode = true; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-eps"}, "F",
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-algorithm"}, "N",
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
[](common_params & params, int value) { params.diffusion.algorithm = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-alg-temp"}, "F",
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-block-length"}, "N",
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
[](common_params & params, int value) { params.diffusion.block_length = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-cfg-scale"}, "F",
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{"--diffusion-add-gumbel-noise"}, "F",
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "-lr", "--learning-rate" }, "ALPHA",
string_format("adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)", (double) params.lr.lr0),
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
string_format("(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
(double) params.lr.lr_min),
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-decay-epochs", "--learning-rate-decay-epochs"}, "ALPHA",
string_format("(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)", (double) params.lr.decay_epochs),
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-wd", "--weight-decay"}, "WD",
string_format("adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).", (double) params.lr.wd),
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-val-split", "--val-split"}, "FRACTION",
string_format("fraction of data to use as validation set for training (default: %.2g).", (double) params.val_split),
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-epochs", "--epochs"}, "N",
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
[](common_params & params, int epochs) { params.lr.epochs = epochs; }
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"-opt", "--optimizer"}, "sgd|adamw", "adamw or sgd",
[](common_params & params, const std::string & name) {
params.optimizer = common_opt_get_optimizer(name.c_str());
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
}
}
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
// presets
add_opt(common_arg(
{"--tts-oute-default"},
string_format("use default OuteTTS models (note: can download weights from the internet)"),
@@ -3863,39 +3945,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_TTS}));
add_opt(common_arg(
{"--embd-bge-small-en-default"},
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
{"--embd-gemma-default"},
string_format("use default EmbeddingGemma model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--embd-e5-small-en-default"},
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
params.model.hf_file = "e5-small-v2-q8_0.gguf";
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--embd-gte-small-default"},
string_format("use default gte-small model (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
params.model.hf_file = "gte-small-q8_0.gguf";
params.embd_normalize = 2;
params.n_ctx = 512;
params.model.hf_repo = "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF";
params.model.hf_file = "embeddinggemma-300M-qat-Q4_0.gguf";
params.port = 8011;
params.n_ubatch = 2048;
params.n_batch = 2048;
params.n_parallel = 32;
params.n_ctx = 2048*params.n_parallel;
params.verbose_prompt = true;
params.embedding = true;
}
@@ -3990,96 +4049,65 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{ "--diffusion-steps" }, "N",
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
[](common_params & params, int value) { params.diffusion.steps = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-visual" },
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
params.diffusion.visual_mode ? "true" : "false"),
[](common_params & params) { params.diffusion.visual_mode = true; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
{"--gpt-oss-20b-default"},
string_format("use gpt-oss-20b (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gpt-oss-20b-GGUF";
params.model.hf_file = "gpt-oss-20b-mxfp4.gguf";
params.port = 8013;
params.n_ubatch = 2048;
params.n_batch = 32768;
params.n_parallel = 2;
params.n_ctx = 131072*params.n_parallel;
params.sampling.temp = 1.0f;
params.sampling.top_p = 1.0f;
params.sampling.top_k = 0;
params.sampling.min_p = 0.01f;
params.use_jinja = true;
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{ "--diffusion-eps" }, "F",
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-algorithm" }, "N",
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
params.diffusion.algorithm),
[](common_params & params, int value) { params.diffusion.algorithm = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-alg-temp" }, "F",
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
{"--gpt-oss-120b-default"},
string_format("use gpt-oss-120b (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gpt-oss-120b-GGUF";
params.port = 8013;
params.n_ubatch = 2048;
params.n_batch = 32768;
params.n_parallel = 2;
params.n_ctx = 131072*params.n_parallel;
params.sampling.temp = 1.0f;
params.sampling.top_p = 1.0f;
params.sampling.top_k = 0;
params.sampling.min_p = 0.01f;
params.use_jinja = true;
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{ "--diffusion-block-length" }, "N",
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
[](common_params & params, int value) { params.diffusion.block_length = value; }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-cfg-scale" }, "F",
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
add_opt(common_arg(
{ "--diffusion-add-gumbel-noise" }, "F",
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
{"--vision-gemma-4b-default"},
string_format("use Gemma 3 4B QAT (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gemma-3-4b-it-qat-GGUF";
params.port = 8014;
params.n_ctx = 0;
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
string_format(
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
(double) params.lr.lr0),
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
string_format(
"(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
(double) params.lr.lr_min),
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
string_format(
"(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
(double) params.lr.decay_epochs),
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{ "-wd", "--weight-decay" }, "WD",
string_format(
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
(double) params.lr.wd),
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
string_format("fraction of data to use as validation set for training (default: %.2g).",
(double) params.val_split),
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
[](common_params & params, const std::string & name) {
params.optimizer = common_opt_get_optimizer(name.c_str());
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
}
})
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
{"--vision-gemma-12b-default"},
string_format("use Gemma 3 12B QAT (note: can download weights from the internet)"),
[](common_params & params) {
params.model.hf_repo = "ggml-org/gemma-3-12b-it-qat-GGUF";
params.port = 8014;
params.n_ctx = 0;
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}
+2 -2
@@ -432,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
if (is_arguments_path({})) {
// Entire JSON is the arguments and was parsed fully.
return consume_json_result {
partial->json.dump(),
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
/* .is_partial = */ false,
};
}
@@ -444,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {
+1 -1
View File
@@ -426,7 +426,7 @@ struct common_params {
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc.
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
+51
@@ -5,6 +5,7 @@
#include <nlohmann/json.hpp>
#include <string>
#include <regex>
using json = nlohmann::ordered_json;
@@ -168,6 +169,47 @@ bool common_json_parse(
}
}
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
auto is_high_surrogate = [&](const std::string & s) {
// Check if a partial of a high surrogate (U+D800-U+DBFF)
return s.length() >= 4 &&
s[0] == '\\' && s[1] == 'u' &&
std::tolower(s[2]) == 'd' &&
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
};
// Initialize the unicode marker to a low surrogate to handle the edge case
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
// backslash (\)
std::string unicode_marker_padding = "udc00";
std::smatch last_unicode_seq;
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
std::smatch second_last_seq;
std::string prelude = str.substr(0, last_unicode_seq.position());
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
if (is_high_surrogate(last_unicode_seq.str())) {
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+DFFF)
unicode_marker_padding += "\\udc00";
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
if (is_high_surrogate(second_last_seq.str())) {
// If this follows a high surrogate, pad it to be a low surrogate
if (last_unicode_seq.length() == 2) {
unicode_marker_padding = "dc00";
} else if (last_unicode_seq.length() == 3) {
unicode_marker_padding = "c00";
} else {
// The original unicode_marker_padding is already padded with 0s
}
}
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
@@ -186,6 +228,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an object value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
@@ -205,6 +250,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an array value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
@@ -230,6 +278,9 @@ bool common_json_parse(
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
// Was inside an object key string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
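The padding step in the hunks above can be tried in isolation. The following is a reduced sketch that reuses the regex from the diff but covers only two cases: zero-padding a trailing partial `\u` escape, and appending a low surrogate when the partial escape starts a high surrogate. The function name `unicode_padding` is an assumption, and the branch for a partial escape that follows a high surrogate is deliberately left out.

```cpp
#include <cctype>
#include <regex>
#include <string>

// Returns the characters needed to turn a trailing partial "\uXXXX" escape
// into a complete one. Returns an empty string when nothing needs padding.
std::string unicode_padding(const std::string & str) {
    // Same pattern as in the diff: \u followed by 0 to 4 hex digits at the end.
    static const std::regex partial(
        R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");

    std::smatch m;
    if (!std::regex_search(str, m, partial)) {
        return "";
    }
    const std::string seq = m.str();

    // Pad with zeros until the escape is 6 characters long (\uXXXX).
    std::string padding(6 - seq.length(), '0');

    // A high surrogate (U+D800-U+DBFF) starts with "d" followed by 8, 9, a or b.
    const auto lower = [](char c) {
        return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    };
    const bool is_high_surrogate =
        seq.length() >= 4 && lower(seq[2]) == 'd' &&
        (seq[3] == '8' || seq[3] == '9' || lower(seq[3]) == 'a' || lower(seq[3]) == 'b');

    if (is_high_surrogate) {
        // A high surrogate is only valid when followed by a low surrogate.
        padding += "\\udc00";
    }
    return padding;
}
```

For a truncated input ending in `\ud8`, the sketch produces `00\udc00`, so the healed text contains the valid pair `\ud800\udc00` and a strict JSON parser can accept it.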
+10 -8
@@ -31,7 +31,7 @@ Legend:
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -51,7 +51,7 @@ Legend:
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
@@ -65,11 +65,11 @@ Legend:
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
@@ -92,9 +92,9 @@ Legend:
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | | ✅ | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
@@ -102,9 +102,11 @@ Legend:
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+12095 -4249
File diff suppressed because it is too large
+100 -107
@@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated(
unary_op(ctx, acl_src0, acl_dst);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
ggml_cann_release_resources(ctx, acl_src0, acl_dst);
if(src1)
ggml_cann_release_resources(ctx, acl_src1);
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
}
/**
@@ -894,14 +892,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
}
/**
* @brief Get or expand a cached float32 tensor filled with a scalar value.
* @brief Get or expand a cached tensor filled with a scalar value.
*
* This function manages cached device memory for float32 tensors. If the current
* This function manages cached device memory for tensors. If the current
* cache size is insufficient for the requested tensor shape, the old memory will
* be released and new memory will be allocated. The allocated buffer is then
* initialized either with zeros (when @p value == 0.0f) or with the given scalar
* value using CANN operations. Finally, an aclTensor object is created from the
* cached memory and returned.
* be released and new memory will be allocated. The allocated buffer is
* initialized with the given scalar value using CANN operations.
* Finally, an aclTensor object is created from the cached memory and returned.
*
* @param ctx The CANN backend context that manages device memory.
* @param buffer A pointer to the cached device buffer (will be allocated
@@ -910,17 +907,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
* updated when the cache is expanded.
* @param ne The tensor shape array (number of elements in each dimension).
* @param nb The stride size for each dimension.
* @param dtype Data type of cached tensor.
* @param dims The number of tensor dimensions.
* @param value The scalar value used to fill the tensor (supports zero
* initialization via memset or arbitrary values via fill_scalar).
* @return An aclTensor pointer created from the cached buffer.
*/
static aclTensor* get_f32_cache_acl_tensor(
static aclTensor* get_cache_acl_tensor(
ggml_backend_cann_context& ctx,
void** buffer,
int64_t &cache_element,
int64_t* ne,
size_t* nb,
ggml_type dtype,
int64_t dims,
float value) {
// Calculate total number of elements
@@ -928,7 +927,7 @@ static aclTensor* get_f32_cache_acl_tensor(
for (int i = 0; i < dims; i++) {
n_element *= ne[i];
}
size_t size = n_element * sizeof(float);
size_t size = n_element * ggml_type_size(dtype);
// Allocate or expand cache if needed
if (cache_element < n_element) {
@@ -941,19 +940,17 @@ static aclTensor* get_f32_cache_acl_tensor(
cache_element = n_element;
// Initialize cache
if (value == 0.0f) {
ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
} else {
int64_t pool_ne[1] = { n_element };
size_t pool_nb[1] = { sizeof(float) };
aclTensor* acl_value = ggml_cann_create_tensor(
*buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
aclnn_fill_scalar(ctx, 1, acl_value);
ggml_cann_release_resources(ctx, acl_value);
}
int64_t pool_ne[1] = { n_element };
size_t pool_nb[1] = { ggml_type_size(dtype) };
aclTensor* acl_value = ggml_cann_create_tensor(
*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
pool_ne, pool_nb, 1);
aclnn_fill_scalar(ctx, value, acl_value);
ggml_cann_release_resources(ctx, acl_value);
}
return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
ggml_type_size(dtype), ne, nb, dims);
}
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -965,35 +962,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// build gamma, one...
// build gamma.
size_t acl_gamma_nb[GGML_MAX_DIMS];
acl_gamma_nb[0] = sizeof(float);
// gamma's type is the same with dst.
acl_gamma_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
}
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
aclTensor* acl_gamma = get_cache_acl_tensor(
ctx,
&ctx.rms_norm_one_tensor_cache.cache,
ctx.rms_norm_one_tensor_cache.size,
src->ne,
acl_gamma_nb,
dst->type,
1, // dims
1.0f // value
);
// build rstd, zero...
// build rstd.
int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
// rstd will always be F32.
acl_rstd_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
}
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
aclTensor* acl_rstd = get_cache_acl_tensor(
ctx,
&ctx.rms_norm_zero_tensor_cache.cache,
ctx.rms_norm_zero_tensor_cache.size,
acl_rstd_ne,
acl_rstd_nb,
GGML_TYPE_F32,
GGML_MAX_DIMS - 1,
0.0f // value
);
@@ -1765,33 +1766,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0]; // src
ggml_tensor* src1 = dst->src[1]; // index
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
switch (src0->type) {
case GGML_TYPE_F32: {
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
break;
}
case GGML_TYPE_F16: {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
case GGML_TYPE_F16:
case GGML_TYPE_F32:
if(src0->type == dst->type) {
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
} else {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = dst->nb[0];
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
}
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
src0->ne, src_trans_nb, GGML_MAX_DIMS);
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
break;
}
case GGML_TYPE_Q8_0: {
// add 1 dim for bcast mul.
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
@@ -1799,7 +1802,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
*dequant_ne;
int64_t scale_offset = 0;
// [3,4,5,64] -> [3,4,5,2,32]
weight_ne[0] = QK8_0;
weight_ne[1] = src0->ne[0] / QK8_0;
@@ -1809,7 +1811,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
weight_ne[i] = src0->ne[i - 1];
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
}
// [3,4,5,64] -> [3,4,5,2,1]
scale_ne[0] = 1;
scale_ne[1] = src0->ne[0] / QK8_0;
@@ -1819,18 +1820,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
scale_ne[i] = src0->ne[i - 1];
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
}
// [3,4,5,64] -> [3,4,5,2,32]
dequant_ne = weight_ne;
dequant_nb[0] = sizeof(float);
dequant_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
}
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
ggml_cann_pool_alloc dequant_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float));
ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
GGML_MAX_DIMS + 1);
@@ -1838,22 +1836,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
aclTensor* dequant_tensor = ggml_cann_create_tensor(
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
dequant_nb[0] = sizeof(float);
dequant_nb[0] = ggml_type_size(dst->type);
dequant_ne = src0->ne;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
}
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
dequant_ne, dequant_nb,
dst->data, dst->ne, dst->nb,
src1, dst->type);
ggml_cann_release_resources(ctx, dequant_tensor);
ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
break;
}
default:
@@ -1965,16 +1961,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
int64_t acl_stride[2] = {1, transpose_ne[1]};
// Reverse ne.
std::reverse(transpose_ne, transpose_ne + n_dims);
std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
acl_weight_tensor = aclCreateTensor(
transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
acl_weight_tensor =
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else {
acl_weight_tensor =
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -3178,7 +3166,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclTensor* acl_src0_f16_tensor = nullptr;
aclTensor* acl_src1_f16_tensor = nullptr;
aclTensor* acl_src2_f16_tensor = nullptr;
aclTensor* acl_dst_f16_tensor = nullptr;
// Step 1: cast the src0 (Query) to fp16 if needed
ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
@@ -3216,22 +3203,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
src2_bsnd_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);
int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
}
acl_dst_f16_tensor = ggml_cann_create_tensor(
out_f16_buffer, faDataType, faElemSize,
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
);
// Step 3: create the PSEShift tensor if needed
// this tensor is considered as mask (f16) in the llama.cpp
aclTensor* bcast_pse_tensor = nullptr;
@@ -3317,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
aclTensor* acl_q_tensor = acl_src0_f16_tensor;
aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
int64_t numHeads = src0->ne[2]; // N
int64_t numKeyValueHeads = src1->ne[2];
@@ -3334,8 +3305,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
int64_t keyAntiquantMode = 0;
int64_t valueAntiquantMode = 0;
// Step 5: launch the FusedInferAttentionScoreV2 kernel.
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
aclTensor * fa_dst_tensor = nullptr;
aclTensor * acl_dst_tensor = nullptr;
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
if (dst->type == GGML_TYPE_F32) {
void* out_f16_buffer = out_f16_allocator.alloc(
ggml_nelements(dst) * faElemSize);
int64_t* out_f16_ne = src0_bsnd_ne;
size_t out_f16_nb[GGML_MAX_DIMS];
out_f16_nb[0] = faElemSize;
for(int i = 1; i < GGML_MAX_DIMS; ++i){
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
}
fa_dst_tensor = ggml_cann_create_tensor(
out_f16_buffer, faDataType, faElemSize,
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
);
}
else {
fa_dst_tensor = ggml_cann_create_tensor(dst);
}
GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
@@ -3357,23 +3349,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
blockSize, antiquantMode, // blockSize, antiquantMode
softmaxLseFlag, // softmaxLseFlag
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
acl_dst_f16_tensor, // attentionOut
fa_dst_tensor, // attentionOut
nullptr // softmaxLse
);
// Step 6: post-processing, permute and cast to f32
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
// TODO: when dst is fp16, don't need cast
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_src1_f16_tensor,
acl_src2_f16_tensor,
acl_dst_f16_tensor,
acl_dst_tensor);
if(src3 != nullptr){
ggml_cann_release_resources(ctx, bcast_pse_tensor);
if (dst->type == GGML_TYPE_F32) {
// Step 6: post-processing, permute and cast to f32
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
}
}else{
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
acl_k_tensor_list,
acl_v_tensor_list,
fa_dst_tensor,
acl_dst_tensor,
bcast_pse_tensor);
} else {
GGML_ABORT("Function is not implemented.");
}
}
+1 -1
@@ -68,7 +68,7 @@ struct ggml_compute_params {
#endif // __VXE2__
#endif // __s390x__ && __VEC__
#if defined(__ARM_FEATURE_SVE)
#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
#include <sys/prctl.h>
#endif
+6 -1
@@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
#endif
static void ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#if defined(__linux__)
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
#else
// TODO: add support of SVE for non-linux systems
#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
#endif
#endif
}
+1 -1
@@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
#endif
for (; i < n; ++i) {
float val = x[i] - mean;
y[i] = val;
val *= val;
sum += (ggml_float)val;
y[i] = val;
}
return sum/n;
}
+5 -4
@@ -144,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
@@ -160,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
@@ -820,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
const float v = GGML_CPU_FP16_TO_FP32(x[i]);
y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
}
}
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
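The ggml_vec_elu_f16 hunk brings the half-precision path in line with the f32 one: positive inputs pass through unchanged and only non-positive inputs go through expm1f. A small reference sketch of that branch in plain float. The name elu_ref is illustrative, and the FP16 conversions done by the GGML_CPU_FP16_TO_FP32 / GGML_CPU_FP32_TO_FP16 macros are deliberately left out:

```cpp
#include <cmath>

// Illustrative scalar ELU with alpha = 1: identity for positive inputs,
// exp(v) - 1 otherwise. Mirrors the branch used by the f32 path in the diff.
float elu_ref(float v) {
    return (v > 0.0f) ? v : std::expm1(v);
}
```

With the earlier f16 body, an input such as 2.0 would have been mapped to expm1(2.0) instead of 2.0, which is what the added branch avoids.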
+17 -27
@@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
}
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
if (!oob_check || i_KQ < k_VKQ_sup) {
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
}
}
KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
@@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
KQ_sum_add += val;
}
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val;
}
KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
@@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
}
}
if (gridDim.y == 1) {
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
}
#else
const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
#pragma unroll
for (int i = 0; i < (DVp/2)/warp_size; ++i) {
VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
}
#endif // FAST_FP16_AVAILABLE
}
}
// Write back results:
#pragma unroll
for (int jc0 = 0; jc0 < cpw; ++jc0) {
@@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
return;
}
const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
#ifdef FAST_FP16_AVAILABLE
@@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
tmp[i1].x *= scale;
tmp[i1].y *= scale;
}
if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
}
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
&VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
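The flash_attn_tile_iter hunks keep out-of-bounds KQ entries out of both the running maximum and the exponentials: the mask add and fmaxf are now guarded together, and val becomes 0.0f for slots at or beyond k_VKQ_sup. A CPU sketch of the same idea under simplifying assumptions (single row, no warp reduction, no running rescale). SoftmaxTile, softmax_tile and n_valid are illustrative names, with n_valid playing the role of k_VKQ_sup:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

struct SoftmaxTile {
    float max;                 // maximum over the valid logits
    float sum;                 // sum of exp(logit - max) over the valid logits
    std::vector<float> weight; // exp(logit - max), 0 for out-of-bounds slots
};

// Illustrative CPU analogue: slots with index >= n_valid are treated as
// out-of-bounds and contribute neither to the maximum nor to the sum.
SoftmaxTile softmax_tile(const std::vector<float> & logits, std::size_t n_valid) {
    SoftmaxTile t{-std::numeric_limits<float>::infinity(), 0.0f,
                  std::vector<float>(logits.size(), 0.0f)};
    const std::size_t n = std::min(n_valid, logits.size());
    for (std::size_t i = 0; i < n; ++i) {
        t.max = std::max(t.max, logits[i]);
    }
    for (std::size_t i = 0; i < n; ++i) {
        t.weight[i] = std::exp(logits[i] - t.max);
        t.sum += t.weight[i];
    }
    return t;
}
```

Dividing each weight by sum at the very end corresponds to the scale factor that the write-back hunks apply only when gridDim.y == 1. If a garbage out-of-bounds value were allowed into the maximum, every valid exponential could underflow toward zero, which is the kind of numerical issue the guard is presumably meant to prevent.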
+56
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_SUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@@ -1482,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_SGD);
char base[256];
char name[256];
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
+3
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -134,6 +135,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
+4
@@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
@@ -798,6 +799,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
};
}
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
default:
return false;
}
+12
@@ -544,6 +544,10 @@ typedef struct{
float limit;
} ggml_metal_kargs_glu;
typedef struct {
uint64_t np;
} ggml_metal_kargs_sum;
typedef struct {
int64_t ne00;
int64_t ne01;
@@ -773,4 +777,12 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;
typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;
#endif // GGML_METAL_IMPL
+106
@@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_glu(ctx, idx);
} break;
case GGML_OP_SUM:
{
n_fuse = ggml_metal_op_sum(ctx, idx);
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
@@ -410,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argmax(ctx, idx);
} break;
case GGML_OP_OPT_STEP_ADAMW:
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
} break;
case GGML_OP_OPT_STEP_SGD:
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -840,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
ggml_metal_kargs_sum args = {
/*.np =*/ n,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -3401,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_adamw args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
const int64_t np = ggml_nelements(op->src[0]);
ggml_metal_kargs_opt_step_sgd args = {
/*.np =*/ np,
};
int ida = 0;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
const int64_t n = (np + nth - 1) / nth;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
return 1;
}
+3
@@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
@@ -78,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus
}
+66
@@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
}
}
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
if (tiitg != 0) {
return;
}
float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
}
dst[0] = acc;
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
@@ -8754,3 +8772,51 @@ kernel void kernel_pool_2d_avg_f32(
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
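To check the two Metal optimizer kernels above on the host, the same arithmetic can be written as a plain C++ loop. This is only a sketch: the names `adamw_step_ref` and `sgd_step_ref` and the `std::vector` interface are made up for illustration and are not part of ggml. Reading `pars[0]` and `pars[1]` of the SGD kernel as learning rate and weight decay is an assumption, based on how the AdamW kernel names the same slots.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference, not part of ggml.
// pars uses the layout read by kernel_opt_step_adamw_f32:
// {alpha, beta1, beta2, eps, wd, beta1h, beta2h}
void adamw_step_ref(std::vector<float> & x, const std::vector<float> & g,
                    std::vector<float> & g_m, std::vector<float> & g_v,
                    const float * pars) {
    assert(x.size() == g.size() && x.size() == g_m.size() && x.size() == g_v.size());
    const float alpha  = pars[0];
    const float beta1  = pars[1];
    const float beta2  = pars[2];
    const float eps    = pars[3];
    const float wd     = pars[4];
    const float beta1h = pars[5];
    const float beta2h = pars[6];

    for (std::size_t i = 0; i < x.size(); ++i) {
        const float gi  = g[i];
        // exponential moving averages of the gradient and its square
        const float gmi = g_m[i] * beta1 + gi * (1.0f - beta1);
        const float gvi = g_v[i] * beta2 + gi * gi * (1.0f - beta2);
        g_m[i] = gmi;
        g_v[i] = gvi;
        // beta1h / beta2h are applied as multiplicative correction factors
        const float mh = gmi * beta1h;
        const float vh = std::sqrt(gvi * beta2h) + eps;
        // decoupled weight decay followed by the scaled update
        x[i] = x[i] * (1.0f - alpha * wd) - alpha * mh / vh;
    }
}

// Assumed meaning: alpha = pars[0] (step size), wd = pars[1] (weight decay).
void sgd_step_ref(std::vector<float> & x, const std::vector<float> & g,
                  float alpha, float wd) {
    assert(x.size() == g.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        x[i] = x[i] * (1.0f - alpha * wd) - alpha * g[i];
    }
}
```

Comparing the output of such a loop against the GPU result for a small tensor is one way to spot indexing or parameter-order mistakes in the kernel dispatch.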
+2
View File
@@ -18,6 +18,7 @@
#include "concat.hpp"
#include "conv.hpp"
#include "convert.hpp"
#include "count-equal.hpp"
#include "cpy.hpp"
#include "dequantize.hpp"
#include "dmmv.hpp"
@@ -28,6 +29,7 @@
#include "mmvq.hpp"
#include "norm.hpp"
#include "outprod.hpp"
#include "pad.hpp"
#include "quantize.hpp"
#include "quants.hpp"
#include "rope.hpp"
-9
View File
@@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_sub(ctx, dst);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_count_equal(ctx, dst);
}
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_mul(ctx, dst);
-6
View File
@@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}
static __dpct_inline__ float op_count_equal(const float a, const float b) {
return (a == b) ? 1.0f : 0.0f;
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}
+2 -1
View File
@@ -195,7 +195,8 @@ struct optimize_feature {
struct sycl_device_info {
int cc; // compute capability
- // int nsm; // number of streaming multiprocessors
+ int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
+ // number of compute units on a SYCL device.
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool vmm; // virtual memory support
+79
View File
@@ -0,0 +1,79 @@
#include "count-equal.hpp"
#include <cstdint>
template <typename T>
static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
int64_t *__restrict__ dst, const int64_t dk,
const int64_t k) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
const int64_t i1 = sycl::min(i0 + dk, k);
int nequal = 0;
for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
const T xi = x[i];
const T yi = y[i];
nequal += xi == yi;
}
nequal = warp_reduce_sum(nequal);
if (item_ct1.get_local_id(2) != 0) {
return;
}
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
(int *)dst, nequal);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT( dst->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
int64_t * dst_d = (int64_t *) dst->data;
dpct::queue_ptr stream = ctx.stream();
const int id = get_current_device_id();
const int nsm = ggml_sycl_info().devices[id].nsm;
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
const int64_t dne =
GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
const dpct::dim3 block_nums(
std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
SYCL_COUNT_EQUAL_CHUNK_SIZE),
1, 1);
switch (src0->type) {
case GGML_TYPE_I32: {
const int *src0_d = (const int *)src0->data;
const int *src1_d = (const int *)src1->data;
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
count_equal(src0_d, src1_d, dst_d, dne, ne);
GGML_UNUSED(item_ct1);
});
} break;
default:
GGML_ASSERT(false);
break;
}
}
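The launch geometry in `ggml_sycl_count_equal` above (chunk length `dne`, at most `4 * nsm` work-groups) can be reproduced on the host to confirm that every element is visited exactly once. The sketch below is an illustration under assumptions: `plan_count_equal` and `count_equal_ref` are hypothetical names, the rounding is written out by hand in place of `GGML_PAD`, and the work-groups are emulated sequentially without the warp reduction or the atomic add.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t CHUNK_SIZE = 128; // same value as SYCL_COUNT_EQUAL_CHUNK_SIZE

struct count_equal_plan {
    int64_t dne;     // elements assigned to one work-group
    int64_t ngroups; // number of work-groups launched
};

// Mirrors the host-side arithmetic of ggml_sycl_count_equal (hypothetical helper).
count_equal_plan plan_count_equal(int64_t ne, int nsm) {
    assert(ne > 0 && nsm > 0);
    const int64_t max_groups = 4 * (int64_t) nsm;
    const int64_t per_group  = (ne + max_groups - 1) / max_groups;
    // round up to a multiple of CHUNK_SIZE (the role GGML_PAD plays in the source)
    const int64_t dne     = (per_group + CHUNK_SIZE - 1) / CHUNK_SIZE * CHUNK_SIZE;
    const int64_t ngroups = std::min(max_groups, (ne + CHUNK_SIZE - 1) / CHUNK_SIZE);
    return {dne, ngroups};
}

// Sequential emulation: group g scans [g*dne, min(g*dne + dne, ne)).
int64_t count_equal_ref(const std::vector<int32_t> & x,
                        const std::vector<int32_t> & y, int nsm) {
    assert(x.size() == y.size());
    const int64_t ne = (int64_t) x.size();
    const count_equal_plan p = plan_count_equal(ne, nsm);
    int64_t total = 0;
    for (int64_t g = 0; g < p.ngroups; ++g) {
        const int64_t i0 = g * p.dne;
        const int64_t i1 = std::min(i0 + p.dne, ne);
        for (int64_t i = i0; i < i1; ++i) {
            total += x[i] == y[i];
        }
    }
    return total;
}
```

Because `dne` is at least `CHUNK_SIZE` and at least `ceil(ne / (4 * nsm))`, `ngroups * dne` is never smaller than `ne` in either branch of the `std::min`, so no tail elements are skipped.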
+9
View File
@@ -0,0 +1,9 @@
#ifndef GGML_SYCL_COUNT_EQUAL_HPP
#define GGML_SYCL_COUNT_EQUAL_HPP
#include "common.hpp"
#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif //GGML_SYCL_COUNT_EQUAL_HPP
-78
View File
@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
}
template <typename T>
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
const sycl::nd_item<3> &item_ct1) {
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
if (nidx >= ne0) {
return;
}
// operation
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
item_ct1.get_group(0) * ne00 * ne01;
dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = static_cast<T>(0.0f);
}
}
template<typename T>
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
const sycl::nd_item<1> &item_ct1) {
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
});
}
template<typename T>
static void pad_sycl(const T *x, T *dst, const int ne00,
const int ne01, const int ne02, const int ne0,
const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
}
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
}
}
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
break;
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
}
}
} // namespace ggml_sycl_detail
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
});
}
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
queue_ptr stream) {
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
});
}
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
float min_val;
float max_val;
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_upscale(ctx, dst);
}
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pad(ctx, dst);
}
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
-2
View File
@@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+63 -40
View File
@@ -85,9 +85,11 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
+ info.devices[i].nsm = prop.get_max_compute_units();
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
- info.max_work_group_sizes[i] = prop.get_max_work_group_size();
+ info.devices[i].smpbo = prop.get_local_mem_size();
+ info.max_work_group_sizes[i] = prop.get_max_work_group_size();
}
for (int id = 0; id < info.device_count; ++id) {
@@ -1512,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
template <ggml_sort_order order>
__dpct_inline__ static void
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
uint8_t *dpct_local) {
// bitonic sort
int col = item_ct1.get_local_id(2);
int col_index = item_ct1.get_local_id(2);
int row = item_ct1.get_group(1);
if (col >= ncols_pad) {
return;
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col >= ncols_pad) {
return;
}
}
const float * x_row = x + row * ncols;
auto dst_row = (int *)dpct_local;
// initialize indices
dst_row[col] = col;
for (int i=0;i<tasks_per_thread;i++){
int col = col_index*tasks_per_thread+i;
dst_row[col] = col;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
: x_row[dst_row[col]] <
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
: x_row[dst_row[col]] >
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}
/*
DPCT1118:1: SYCL group functions and algorithms must be encountered
in converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
}
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
const sycl::nd_item<3> &item_ct1) {
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
@@ -1738,11 +1750,20 @@ static int next_power_of_2(int x) {
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const int nrows, ggml_sort_order order,
- queue_ptr stream) {
+ queue_ptr stream, int device) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
- const sycl::range<3> block_dims(1, 1, ncols_pad);
+ int nth = 1;
+ int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+ while (nth < ncols_pad && nth < max_block_size)
+ nth *= 2;
+ if (nth > max_block_size)
+ nth = max_block_size;
+ const int tasks_per_thread = ncols_pad / nth;
+ const sycl::range<3> block_dims(1, 1, nth);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
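The launcher change above caps the work-group size at the device limit and gives each work-item `tasks_per_thread` columns of the padded row. The following host-side sketch emulates that scheme for the ascending case. It is an illustration, not the backend code: `plan_argsort` and `argsort_asc_ref` are hypothetical names, work-items are run one after another instead of in parallel with barriers, and `max_block_size` is assumed to be a power of two so that `nth` divides `ncols_pad`.

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct argsort_plan {
    int ncols_pad;        // row length rounded up to a power of two
    int nth;              // work-items per work-group
    int tasks_per_thread; // columns handled by each work-item
};

static int next_pow2(int x) {
    int n = 1;
    while (n < x) n *= 2;
    return n;
}

// Same sizing logic as the updated argsort_f32_i32_sycl (hypothetical helper).
argsort_plan plan_argsort(int ncols, int max_block_size) {
    const int ncols_pad = next_pow2(ncols);
    int nth = 1;
    while (nth < ncols_pad && nth < max_block_size) nth *= 2;
    if (nth > max_block_size) nth = max_block_size;
    return {ncols_pad, nth, ncols_pad / nth};
}

// Sequential emulation of the bitonic argsort; indices >= ncols are padding
// and compare as larger than every real element.
std::vector<int> argsort_asc_ref(const std::vector<float> & x, int max_block_size) {
    const int ncols = (int) x.size();
    const argsort_plan p = plan_argsort(ncols, max_block_size);
    std::vector<int> idx(p.ncols_pad);
    for (int c = 0; c < p.ncols_pad; ++c) idx[c] = c;

    for (int k = 2; k <= p.ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int t = 0; t < p.nth; ++t) {                  // emulated work-item
                for (int i = 0; i < p.tasks_per_thread; ++i) { // its columns
                    const int col = t * p.tasks_per_thread + i;
                    const int ixj = col ^ j;
                    if (ixj <= col) continue;
                    const int a = idx[col];
                    const int b = idx[ixj];
                    const bool do_swap = (col & k) == 0
                        ? (a >= ncols || (b < ncols && x[a] > x[b]))
                        : (b >= ncols || (a < ncols && x[a] < x[b]));
                    if (do_swap) std::swap(idx[col], idx[ixj]);
                }
            }
        }
    }
    idx.resize(ncols); // drop the padding slots
    return idx;
}
```

Running the emulation sequentially is valid here because, within one `(k, j)` stage, each compare-and-swap touches a disjoint pair of slots.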
@@ -1755,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
- x, dst, ncols, ncols_pad, item_ct1,
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
+ dpct_local_acc_ct1
+ .get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
@@ -1769,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
- x, dst, ncols, ncols_pad, item_ct1,
- dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+ x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
+ dpct_local_acc_ct1
+ .get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
@@ -2142,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
- argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
+ argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
+ main_stream, ctx.device);
}
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -4413,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ACC:
return true;
case GGML_OP_PAD:
- return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
- (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
+ return ggml_is_contiguous(op->src[0]);
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
+97
View File
@@ -0,0 +1,97 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//#include "common.hpp"
#include "pad.hpp"
static void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
int i0 = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
int i1 = item_ct1.get_group(1);
int i2 = item_ct1.get_group(0) % ne2;
int i3 = item_ct1.get_group(0) / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t i00 = i0 - lp0;
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
i02 * (ne00 * ne01) + i01 * ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] = 0.0f;
}
}
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3,
const int rp3, const int ne0, const int ne1,
const int ne2, const int ne3,
dpct::queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
ne2, ne3);
});
}
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_sycl(src0_d, dst_d,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pad(ctx, dst);
}
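The index mapping in `pad_f32` above (destination coordinate minus left padding gives the source coordinate, everything outside is zero) can be checked against a small host-side reference. This is a minimal sketch under assumptions: `pad_ref` is a hypothetical helper, tensors are contiguous `float` buffers, and the loops run sequentially instead of one work-item per destination element.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference for the index mapping used by pad_f32.
// ne: destination extents; lp / rp: left and right padding per dimension.
std::vector<float> pad_ref(const std::vector<float> & src,
                           const std::array<int, 4> & ne,
                           const std::array<int, 4> & lp,
                           const std::array<int, 4> & rp) {
    // source extents are the destination extents minus both paddings
    const int64_t ne00 = ne[0] - lp[0] - rp[0];
    const int64_t ne01 = ne[1] - lp[1] - rp[1];
    const int64_t ne02 = ne[2] - lp[2] - rp[2];
    const int64_t ne03 = ne[3] - lp[3] - rp[3];
    assert(ne00 > 0 && ne01 > 0 && ne02 > 0 && ne03 > 0);
    assert((int64_t) src.size() == ne00 * ne01 * ne02 * ne03);

    const int64_t n0 = ne[0], n1 = ne[1], n2 = ne[2], n3 = ne[3];
    std::vector<float> dst((std::size_t) (n0 * n1 * n2 * n3), 0.0f);

    for (int64_t i3 = 0; i3 < n3; ++i3)
    for (int64_t i2 = 0; i2 < n2; ++i2)
    for (int64_t i1 = 0; i1 < n1; ++i1)
    for (int64_t i0 = 0; i0 < n0; ++i0) {
        const bool inside = i0 >= lp[0] && i0 < n0 - rp[0] &&
                            i1 >= lp[1] && i1 < n1 - rp[1] &&
                            i2 >= lp[2] && i2 < n2 - rp[2] &&
                            i3 >= lp[3] && i3 < n3 - rp[3];
        if (!inside) {
            continue; // padding region stays zero
        }
        const int64_t dst_idx = i3 * (n0 * n1 * n2) + i2 * (n0 * n1) + i1 * n0 + i0;
        const int64_t src_idx = (i3 - lp[3]) * (ne00 * ne01 * ne02) +
                                (i2 - lp[2]) * (ne00 * ne01) +
                                (i1 - lp[1]) * ne00 + (i0 - lp[0]);
        dst[(std::size_t) dst_idx] = src[(std::size_t) src_idx];
    }
    return dst;
}
```

For example, padding the row `{1, 2, 3}` with one element on the left and two on the right places the source values at destination positions 1 to 3 and leaves the other three slots at zero.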
+24
View File
@@ -0,0 +1,24 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_PAD_HPP
#define GGML_SYCL_PAD_HPP
#include "common.hpp"
#define SYCL_PAD_BLOCK_SIZE 256
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_PAD_HPP
+5 -1
View File
@@ -140,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const {
}
bool llama_hparams::is_recurrent(uint32_t il) const {
- return recurrent_layer_arr[il];
+ if (il < n_layer) {
+ return recurrent_layer_arr[il];
+ }
+ GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer);
}
uint32_t llama_hparams::n_pos_per_embd() const {
+58
View File
@@ -524,6 +524,64 @@ static void test_json_with_dumped_args() {
R"({"foo": "bar", "args": {"arg1": [)",
R"({"foo":"bar","args":"{\"arg1\":["})"
);
// Unicode tests
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u0)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u00)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u000)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\u0000)",
R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud8)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud80)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\u)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\ud)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\udc)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})"
);
test_with_args(
R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)",
R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})"
);
}
static void test_positions() {
+51 -1
View File
@@ -58,7 +58,7 @@ static void test_json_healing() {
for (const auto & input : inputs) {
common_json out;
assert_equals(true, common_json_parse(input, "$foo", out));
- assert_equals<std::string>(expected, out.json.dump());
+ assert_equals<std::string>(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true));
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
}
};
@@ -228,6 +228,56 @@ static void test_json_healing() {
R"({"key":"$foo"})",
R"(:"$foo)"
);
// Test unicode escape sequences
test(
{
R"({"a":"\u)",
},
R"({"a":"\u0000$foo"})",
R"(0000$foo)"
);
test(
{
R"({"a":"\u00)",
},
R"({"a":"\u0000$foo"})",
R"(00$foo)"
);
test(
{
R"({"a":"\ud300)",
},
R"({"a":"\ud300$foo"})",
R"($foo)"
);
test(
{
R"({"a":"\ud800)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(\udc00$foo)"
);
test(
{
R"({"a":"\ud800\)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(udc00$foo)"
);
test(
{
R"({"a":"\ud800\u)",
},
R"({"a":"\ud800\udc00$foo"})",
R"(dc00$foo)"
);
test(
{
R"({"a":"\ud800\udc00)",
},
R"({"a":"\ud800\udc00$foo"})",
R"($foo)"
);
}
int main() {
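In the healing tests above, a complete `\ud300` only receives the marker, while `\ud800` is first completed with `\udc00`. A plausible reading is that `0xD800` is a UTF-16 high surrogate, which is not a valid character unless a low surrogate follows it. The sketch below shows only that UTF-16 rule; the helper names are hypothetical and are not taken from the parser under test.

```cpp
#include <cassert>
#include <cstdint>

// UTF-16 surrogate ranges as defined by the Unicode standard.
// Helper names are illustrative only.
bool is_high_surrogate(uint32_t u) { return u >= 0xD800 && u <= 0xDBFF; }
bool is_low_surrogate(uint32_t u)  { return u >= 0xDC00 && u <= 0xDFFF; }

// Combine a surrogate pair into the code point it encodes.
uint32_t combine_surrogates(uint32_t hi, uint32_t lo) {
    assert(is_high_surrogate(hi) && is_low_surrogate(lo));
    return 0x10000 + ((hi - 0xD800) << 10) + (lo - 0xDC00);
}
```

Under this rule `0xD300` stands on its own, while `0xD800` followed by `0xDC00` encodes U+10000, the first code point outside the Basic Multilingual Plane.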
Binary file not shown.
+69
View File
@@ -50,6 +50,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
"prettier": "^3.4.2",
@@ -66,6 +67,7 @@
"tw-animate-css": "^1.3.5",
"typescript": "^5.0.0",
"typescript-eslint": "^8.20.0",
"unified": "^11.0.5",
"uuid": "^13.0.0",
"vite": "^7.0.4",
"vite-plugin-devtools-json": "^0.2.0",
@@ -2128,6 +2130,66 @@
"node": ">=14.0.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": {
"version": "1.4.3",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/wasi-threads": "1.0.2",
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": {
"version": "1.4.3",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
"version": "1.0.2",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": {
"version": "0.2.11",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/core": "^1.4.3",
"@emnapi/runtime": "^1.4.3",
"@tybys/wasm-util": "^0.9.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": {
"version": "0.9.0",
"dev": true,
"inBundle": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": {
"version": "2.8.0",
"dev": true,
"inBundle": true,
"license": "0BSD",
"optional": true
},
"node_modules/@tailwindcss/oxide-win32-arm64-msvc": {
"version": "4.1.11",
"resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.11.tgz",
@@ -4946,6 +5008,13 @@
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/mdast": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/mdast/-/mdast-3.0.0.tgz",
"integrity": "sha512-xySmf8g4fPKMeC07jXGz971EkLbWAJ83s4US2Tj9lEdnZ142UP5grN73H1Xd3HzrdbU5o9GYYP/y8F9ZSwLE9g==",
"dev": true,
"license": "MIT"
},
"node_modules/mdast-util-find-and-replace": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/mdast-util-find-and-replace/-/mdast-util-find-and-replace-3.0.2.tgz",
+2
View File
@@ -52,6 +52,7 @@
"eslint-plugin-svelte": "^3.0.0",
"fflate": "^0.8.2",
"globals": "^16.0.0",
"mdast": "^3.0.0",
"mdsvex": "^0.12.3",
"playwright": "^1.53.0",
"prettier": "^3.4.2",
@@ -68,6 +69,7 @@
"tw-animate-css": "^1.3.5",
"typescript": "^5.0.0",
"typescript-eslint": "^8.20.0",
"unified": "^11.0.5",
"uuid": "^13.0.0",
"vite": "^7.0.4",
"vite-plugin-devtools-json": "^0.2.0",
@@ -7,6 +7,7 @@
ChatMessages,
ChatProcessingInfo,
EmptyFileAlertDialog,
ChatErrorDialog,
ServerErrorSplash,
ServerInfo,
ServerLoadingSplash,
@@ -22,10 +23,11 @@
activeMessages,
activeConversation,
deleteConversation,
+ dismissErrorDialog,
+ errorDialog,
isLoading,
sendMessage,
- stopGeneration,
- setMaxContextError
+ stopGeneration
} from '$lib/stores/chat.svelte';
import {
supportsVision,
@@ -34,7 +36,6 @@
serverWarning,
serverStore
} from '$lib/stores/server.svelte';
import { contextService } from '$lib/services';
import { parseFilesToMessageExtras } from '$lib/utils/convert-files-to-extra';
import { isFileTypeSupported } from '$lib/utils/file-type';
import { filterFilesByModalities } from '$lib/utils/modality-file-validation';
@@ -79,6 +80,7 @@
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
);
let activeErrorDialog = $derived(errorDialog());
let isServerLoading = $derived(serverLoading());
async function handleDeleteConfirm() {
@@ -105,6 +107,12 @@
}
}
function handleErrorDialogOpenChange(open: boolean) {
if (!open) {
dismissErrorDialog();
}
}
function handleDragOver(event: DragEvent) {
event.preventDefault();
}
@@ -183,21 +191,6 @@
const extras = result?.extras;
// Check context limit using real-time slots data
const contextCheck = await contextService.checkContextLimit();
if (contextCheck && contextCheck.wouldExceed) {
const errorMessage = contextService.getContextErrorMessage(contextCheck);
setMaxContextError({
message: errorMessage,
estimatedTokens: contextCheck.currentUsage,
maxContext: contextCheck.maxContext
});
return false;
}
// Enable autoscroll for user-initiated message sending
userScrolledUp = false;
autoScrollEnabled = true;
@@ -461,6 +454,13 @@
}}
/>
<ChatErrorDialog
message={activeErrorDialog?.message ?? ''}
onOpenChange={handleErrorDialogOpenChange}
open={Boolean(activeErrorDialog)}
type={activeErrorDialog?.type ?? 'server'}
/>
<style>
.conversation-chat-form {
position: relative;
@@ -0,0 +1,60 @@
<script lang="ts">
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { AlertTriangle, TimerOff } from '@lucide/svelte';
interface Props {
open: boolean;
type: 'timeout' | 'server';
message: string;
onOpenChange?: (open: boolean) => void;
}
let { open = $bindable(), type, message, onOpenChange }: Props = $props();
const isTimeout = $derived(type === 'timeout');
const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
const description = $derived(
isTimeout
? 'The request did not receive a response from the server before timing out.'
: 'The server responded with an error message. Review the details below.'
);
const iconClass = $derived(isTimeout ? 'text-destructive' : 'text-amber-500');
const badgeClass = $derived(
isTimeout
? 'border-destructive/40 bg-destructive/10 text-destructive'
: 'border-amber-500/40 bg-amber-500/10 text-amber-600 dark:text-amber-400'
);
function handleOpenChange(newOpen: boolean) {
open = newOpen;
onOpenChange?.(newOpen);
}
</script>
<AlertDialog.Root {open} onOpenChange={handleOpenChange}>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title class="flex items-center gap-2">
{#if isTimeout}
<TimerOff class={`h-5 w-5 ${iconClass}`} />
{:else}
<AlertTriangle class={`h-5 w-5 ${iconClass}`} />
{/if}
{title}
</AlertDialog.Title>
<AlertDialog.Description>
{description}
</AlertDialog.Description>
</AlertDialog.Header>
<div class={`rounded-lg border px-4 py-3 text-sm ${badgeClass}`}>
<p class="font-medium">{message}</p>
</div>
<AlertDialog.Footer>
<AlertDialog.Action onclick={() => handleOpenChange(false)}>Close</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>
@@ -1,66 +0,0 @@
<script lang="ts">
import { AlertTriangle } from '@lucide/svelte';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { maxContextError, clearMaxContextError } from '$lib/stores/chat.svelte';
</script>
<AlertDialog.Root
open={maxContextError() !== null}
onOpenChange={(open) => !open && clearMaxContextError()}
>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title class="flex items-center gap-2">
<AlertTriangle class="h-5 w-5 text-destructive" />
Message Too Long
</AlertDialog.Title>
<AlertDialog.Description>
Your message exceeds the model's context window and cannot be processed.
</AlertDialog.Description>
</AlertDialog.Header>
{#if maxContextError()}
<div class="space-y-3 text-sm">
<div class="rounded-lg bg-muted p-3">
<div class="mb-2 font-medium">Token Usage:</div>
<div class="space-y-1 text-muted-foreground">
<div>
Estimated tokens:
<span class="font-mono">
{maxContextError()?.estimatedTokens.toLocaleString()}
</span>
</div>
<div>
Context window:
<span class="font-mono">
{maxContextError()?.maxContext.toLocaleString()}
</span>
</div>
</div>
</div>
<div>
<div class="mb-2 font-medium">Suggestions:</div>
<ul class="list-inside list-disc space-y-1 text-muted-foreground">
<li>Shorten your message</li>
<li>Remove some file attachments</li>
<li>Start a new conversation</li>
</ul>
</div>
</div>
{/if}
<AlertDialog.Footer>
<AlertDialog.Action onclick={() => clearMaxContextError()}>Got it</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>
@@ -30,12 +30,11 @@ export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
export { default as ChatSidebarSearch } from './chat/ChatSidebar/ChatSidebarSearch.svelte';
export { default as ChatErrorDialog } from './dialogs/ChatErrorDialog.svelte';
export { default as EmptyFileAlertDialog } from './dialogs/EmptyFileAlertDialog.svelte';
export { default as ConversationTitleUpdateDialog } from './dialogs/ConversationTitleUpdateDialog.svelte';
export { default as MaximumContextAlertDialog } from './dialogs/MaximumContextAlertDialog.svelte';
export { default as KeyboardShortcutInfo } from './misc/KeyboardShortcutInfo.svelte';
export { default as MarkdownContent } from './misc/MarkdownContent.svelte';
@@ -14,6 +14,7 @@
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
import githubLightCss from 'highlight.js/styles/github.css?inline';
import { mode } from 'mode-watcher';
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
interface Props {
content: string;
@@ -50,36 +51,59 @@
.use(remarkGfm) // GitHub Flavored Markdown
.use(remarkMath) // Parse $inline$ and $$block$$ math
.use(remarkBreaks) // Convert line breaks to <br>
.use(remarkRehype) // Convert to rehype (HTML AST)
.use(remarkLiteralHtml) // Treat raw HTML as literal text with preserved indentation
.use(remarkRehype) // Convert Markdown AST to rehype
.use(rehypeKatex) // Render math using KaTeX
.use(rehypeHighlight) // Add syntax highlighting
.use(rehypeStringify); // Convert to HTML string
});
function enhanceLinks(html: string): string {
if (!html.includes('<a')) {
return html;
}
const tempDiv = document.createElement('div');
tempDiv.innerHTML = html;
// Make all links open in new tabs
const linkElements = tempDiv.querySelectorAll('a[href]');
let mutated = false;
for (const link of linkElements) {
const target = link.getAttribute('target');
const rel = link.getAttribute('rel');
if (target !== '_blank' || rel !== 'noopener noreferrer') {
mutated = true;
}
link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer');
}
return tempDiv.innerHTML;
return mutated ? tempDiv.innerHTML : html;
}
function enhanceCodeBlocks(html: string): string {
if (!html.includes('<pre')) {
return html;
}
const tempDiv = document.createElement('div');
tempDiv.innerHTML = html;
const preElements = tempDiv.querySelectorAll('pre');
let mutated = false;
for (const [index, pre] of Array.from(preElements).entries()) {
const codeElement = pre.querySelector('code');
if (!codeElement) continue;
if (!codeElement) {
continue;
}
mutated = true;
let language = 'text';
const classList = Array.from(codeElement.classList);
@@ -127,7 +151,7 @@
pre.parentNode?.replaceChild(wrapper, pre);
}
return tempDiv.innerHTML;
return mutated ? tempDiv.innerHTML : html;
}
async function processMarkdown(text: string): Promise<string> {
@@ -0,0 +1,15 @@
export const LINE_BREAK = /\r?\n/;
export const PHRASE_PARENTS = new Set([
'paragraph',
'heading',
'emphasis',
'strong',
'delete',
'link',
'linkReference',
'tableCell'
]);
export const NBSP = '\u00a0';
export const TAB_AS_SPACES = NBSP.repeat(4);
@@ -0,0 +1,121 @@
import type { Plugin } from 'unified';
import { visit } from 'unist-util-visit';
import type { Break, Content, Paragraph, PhrasingContent, Root, Text } from 'mdast';
import { LINE_BREAK, NBSP, PHRASE_PARENTS, TAB_AS_SPACES } from '$lib/constants/literal-html';
/**
* remark plugin that rewrites raw HTML nodes into plain-text equivalents.
*
* remark parses inline HTML into `html` nodes even when we do not want to render
* them. We turn each of those nodes into regular text (plus `<br>` break markers)
* so the downstream rehype pipeline escapes the characters instead of executing
* them. Leading spaces and tab characters are converted to nonbreaking spaces to
* keep indentation identical to the original author input.
*/
function preserveIndent(line: string): string {
let index = 0;
let output = '';
while (index < line.length) {
const char = line[index];
if (char === ' ') {
output += NBSP;
index += 1;
continue;
}
if (char === '\t') {
output += TAB_AS_SPACES;
index += 1;
continue;
}
break;
}
return output + line.slice(index);
}
function createLiteralChildren(value: string): PhrasingContent[] {
const lines = value.split(LINE_BREAK);
const nodes: PhrasingContent[] = [];
for (const [lineIndex, rawLine] of lines.entries()) {
if (lineIndex > 0) {
nodes.push({ type: 'break' } as Break as unknown as PhrasingContent);
}
nodes.push({
type: 'text',
value: preserveIndent(rawLine)
} as Text as unknown as PhrasingContent);
}
if (!nodes.length) {
nodes.push({ type: 'text', value: '' } as Text as unknown as PhrasingContent);
}
return nodes;
}
export const remarkLiteralHtml: Plugin<[], Root> = () => {
return (tree) => {
visit(tree, 'html', (node, index, parent) => {
if (!parent || typeof index !== 'number') {
return;
}
const replacement = createLiteralChildren(node.value);
if (!PHRASE_PARENTS.has(parent.type as string)) {
const paragraph: Paragraph = {
type: 'paragraph',
children: replacement as Paragraph['children'],
data: { literalHtml: true }
};
const siblings = parent.children as unknown as Content[];
siblings.splice(index, 1, paragraph as unknown as Content);
if (index > 0) {
const previous = siblings[index - 1] as Paragraph | undefined;
if (
previous?.type === 'paragraph' &&
(previous.data as { literalHtml?: boolean } | undefined)?.literalHtml
) {
const prevChildren = previous.children as unknown as PhrasingContent[];
if (prevChildren.length) {
const lastChild = prevChildren[prevChildren.length - 1];
if (lastChild.type !== 'break') {
prevChildren.push({
type: 'break'
} as Break as unknown as PhrasingContent);
}
}
prevChildren.push(...(paragraph.children as unknown as PhrasingContent[]));
siblings.splice(index, 1);
return index;
}
}
return index + 1;
}
(parent.children as unknown as PhrasingContent[]).splice(
index,
1,
...(replacement as unknown as PhrasingContent[])
);
return index + replacement.length;
});
};
};
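The string-to-node step of the plugin above can be tried in isolation. The following is a minimal sketch, not the plugin itself: it drops the `unified`/`mdast` imports, uses a local `LiteralNode` type and a hypothetical `toLiteralNodes` helper as stand-ins, and only assumes the same `NBSP` and four-space tab constants shown in the diff.

```typescript
// Standalone sketch: no remark or unified imports, only the string-to-node step.
const NBSP = '\u00a0';
const TAB_AS_SPACES = NBSP.repeat(4);

// Local stand-in for the mdast Text and Break node types (assumption for illustration).
type LiteralNode = { type: 'text'; value: string } | { type: 'break' };

// Replace leading spaces and tabs with non-breaking spaces; the rest of the line is untouched.
function preserveIndent(line: string): string {
  const indent = line.match(/^[ \t]*/)?.[0] ?? '';
  const converted = Array.from(indent, (ch) => (ch === '\t' ? TAB_AS_SPACES : NBSP)).join('');
  return converted + line.slice(indent.length);
}

// One text node per source line, with a break node between consecutive lines.
function toLiteralNodes(value: string): LiteralNode[] {
  const nodes: LiteralNode[] = [];
  value.split(/\r?\n/).forEach((line, i) => {
    if (i > 0) nodes.push({ type: 'break' });
    nodes.push({ type: 'text', value: preserveIndent(line) });
  });
  return nodes;
}

// Three source lines become text, break, text, break, text.
const example = toLiteralNodes('<ul>\n\t<li>item</li>\n</ul>');
console.log(example.map((n) => n.type).join(','));
```

Because the result is made of `text` nodes rather than `html` nodes, the later stringify stage escapes the angle brackets instead of emitting markup, which is the effect the commit message describes.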
+26 -63
@@ -13,7 +13,7 @@ import { slotsService } from './slots';
* - Manages streaming and non-streaming response parsing
* - Provides request abortion capabilities
* - Converts database messages to API format
* - Handles error translation and context detection
* - Handles error translation for server responses
*
* - **ChatStore**: Stateful orchestration and UI state management
* - Uses ChatService for all AI model communication
@@ -26,7 +26,6 @@ import { slotsService } from './slots';
* - Streaming response handling with real-time callbacks
* - Reasoning content extraction and processing
* - File attachment processing (images, PDFs, audio, text)
* - Context error detection and reporting
* - Request lifecycle management (abort, cleanup)
*/
export class ChatService {
@@ -209,10 +208,13 @@ export class ChatService {
userFriendlyError = new Error(
'Unable to connect to server - please check if the server is running'
);
userFriendlyError.name = 'NetworkError';
} else if (error.message.includes('ECONNREFUSED')) {
userFriendlyError = new Error('Connection refused - server may be offline');
userFriendlyError.name = 'NetworkError';
} else if (error.message.includes('ETIMEDOUT')) {
userFriendlyError = new Error('Request timeout - server may be overloaded');
userFriendlyError = new Error('Request timed out - the server took too long to respond');
userFriendlyError.name = 'TimeoutError';
} else {
userFriendlyError = error;
}
@@ -262,6 +264,7 @@ export class ChatService {
let fullReasoningContent = '';
let hasReceivedData = false;
let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
try {
let chunk = '';
@@ -277,18 +280,8 @@ export class ChatService {
if (line.startsWith('data: ')) {
const data = line.slice(6);
if (data === '[DONE]') {
if (!hasReceivedData && aggregatedContent.length === 0) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
return;
}
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
return;
streamFinished = true;
continue;
}
try {
@@ -326,13 +319,13 @@ export class ChatService {
}
}
if (!hasReceivedData && aggregatedContent.length === 0) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
return;
if (streamFinished) {
if (!hasReceivedData && aggregatedContent.length === 0) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
}
} catch (error) {
const err = error instanceof Error ? error : new Error('Stream error');
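The streaming hunks above change `[DONE]` from an early return into a flag that is evaluated once the read loop ends. A minimal sketch of that control flow follows, assuming pre-split lines rather than a live `ReadableStream` and an OpenAI-style `choices[0].delta.content` chunk shape; `consumeSseLines` and `StreamResult` are hypothetical names, not part of `ChatService`.

```typescript
// Hypothetical helper: consumes already-split SSE lines instead of a live stream.
interface StreamResult {
  content: string;
  finished: boolean;
}

function consumeSseLines(lines: string[]): StreamResult {
  let content = '';
  let hasReceivedData = false;
  let streamFinished = false;

  for (const line of lines) {
    if (!line.startsWith('data: ')) continue;
    const data = line.slice(6);

    if (data === '[DONE]') {
      // Record the terminator and keep looping; the outcome is decided after the loop.
      streamFinished = true;
      continue;
    }

    try {
      // Assumed chunk shape (OpenAI-compatible); the diff does not show the parsing code.
      const parsed = JSON.parse(data) as { choices?: { delta?: { content?: string } }[] };
      const delta = parsed.choices?.[0]?.delta?.content;
      if (delta) {
        hasReceivedData = true;
        content += delta;
      }
    } catch {
      // Malformed chunks are skipped in this sketch.
    }
  }

  // A finished stream with no payload is reported as a generic empty response,
  // no longer as a context-size error.
  if (streamFinished && !hasReceivedData && content.length === 0) {
    throw new Error('No response received from server. Please try again.');
  }
  return { content, finished: streamFinished };
}

const result = consumeSseLines(['data: {"choices":[{"delta":{"content":"Hi"}}]}', 'data: [DONE]']);
console.log(result.content, result.finished);
```

Deferring the decision to a single place after the loop means the completion callback and the empty-response error cannot both fire for the same stream.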
@@ -368,12 +361,8 @@ export class ChatService {
const responseText = await response.text();
if (!responseText.trim()) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
throw contextError;
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
const data: ApiChatCompletionResponse = JSON.parse(responseText);
@@ -385,22 +374,14 @@ export class ChatService {
}
if (!content.trim()) {
const contextError = new Error(
'The request exceeds the available context size. Try increasing the context size or enable context shift.'
);
contextError.name = 'ContextError';
onError?.(contextError);
throw contextError;
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(content, reasoningContent);
return content;
} catch (error) {
if (error instanceof Error && error.name === 'ContextError') {
throw error;
}
const err = error instanceof Error ? error : new Error('Parse error');
onError?.(err);
@@ -594,37 +575,19 @@ export class ChatService {
const errorText = await response.text();
const errorData: ApiErrorResponse = JSON.parse(errorText);
if (errorData.error?.type === 'exceed_context_size_error') {
const contextError = errorData.error as ApiContextSizeError;
const error = new Error(contextError.message);
error.name = 'ContextError';
// Attach structured context information
(
error as Error & {
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
}
).contextInfo = {
promptTokens: contextError.n_prompt_tokens,
maxContext: contextError.n_ctx,
estimatedTokens: contextError.n_prompt_tokens
};
return error;
}
// Fallback for other error types
const message = errorData.error?.message || 'Unknown server error';
return new Error(message);
const error = new Error(message);
error.name = response.status === 400 ? 'ServerError' : 'HttpError';
return error;
} catch {
// If we can't parse the error response, return a generic error
return new Error(`Server error (${response.status}): ${response.statusText}`);
const fallback = new Error(`Server error (${response.status}): ${response.statusText}`);
fallback.name = 'HttpError';
return fallback;
}
}
/**
* Updates the processing state with timing information from the server response
* @param timings - Timing data from the API response
* @param promptProgress - Progress data from the API response
*/
private updateProcessingState(
timings?: ChatMessageTimings,
promptProgress?: ChatMessagePromptProgress
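The error-parsing hunk above replaces the special-cased `exceed_context_size_error` branch with a uniform naming rule. As a rough sketch of that rule outside the class, assuming a plain status code, status text, and body string in place of a `fetch` `Response` (`toNamedError` is a hypothetical name):

```typescript
// Hypothetical standalone version of the naming rule shown in the hunk above.
function toNamedError(status: number, statusText: string, body: string): Error {
  try {
    const data = JSON.parse(body) as { error?: { message?: string } };
    const error = new Error(data.error?.message || 'Unknown server error');
    // HTTP 400 is labelled 'ServerError'; every other status is 'HttpError'.
    error.name = status === 400 ? 'ServerError' : 'HttpError';
    return error;
  } catch {
    // The body was not usable JSON: fall back to the status line.
    const fallback = new Error(`Server error (${status}): ${statusText}`);
    fallback.name = 'HttpError';
    return fallback;
  }
}

console.log(toNamedError(400, 'Bad Request', '{"error":{"message":"prompt too long"}}').name);
console.log(toNamedError(502, 'Bad Gateway', '<html>').message);
```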
@@ -1,102 +0,0 @@
import { slotsService } from './slots';
export interface ContextCheckResult {
wouldExceed: boolean;
currentUsage: number;
maxContext: number;
availableTokens: number;
reservedTokens: number;
}
/**
* ContextService - Context window management and limit checking
*
* This service provides context window monitoring and limit checking using real-time
* server data from the slots service. It helps prevent context overflow by tracking
* current usage and calculating available space for new content.
*
* **Architecture & Relationships:**
* - **ContextService** (this class): Context limit monitoring
* - Uses SlotsService for real-time context usage data
* - Calculates available tokens with configurable reserves
* - Provides context limit checking and error messaging
* - Helps prevent context window overflow
*
* - **SlotsService**: Provides current context usage from server slots
* - **ChatStore**: Uses context checking before sending messages
* - **UI Components**: Display context usage warnings and limits
*
* **Key Features:**
* - **Real-time Context Checking**: Uses live server data for accuracy
* - **Token Reservation**: Reserves tokens for response generation
* - **Limit Detection**: Prevents context window overflow
* - **Usage Reporting**: Detailed context usage statistics
* - **Error Messaging**: User-friendly context limit messages
* - **Configurable Reserves**: Adjustable token reservation for responses
*
* **Context Management:**
* - Monitors current context usage from active slots
* - Calculates available space considering reserved tokens
* - Provides early warning before context limits are reached
* - Helps optimize conversation length and content
*/
export class ContextService {
private reserveTokens: number;
constructor(reserveTokens = 512) {
this.reserveTokens = reserveTokens;
}
/**
* Checks if the context limit would be exceeded
*
* @returns {Promise<ContextCheckResult | null>} Promise that resolves to the context check result or null if an error occurs
*/
async checkContextLimit(): Promise<ContextCheckResult | null> {
try {
const currentState = await slotsService.getCurrentState();
if (!currentState) {
return null;
}
const maxContext = currentState.contextTotal;
const currentUsage = currentState.contextUsed;
const availableTokens = maxContext - currentUsage - this.reserveTokens;
const wouldExceed = availableTokens <= 0;
return {
wouldExceed,
currentUsage,
maxContext,
availableTokens: Math.max(0, availableTokens),
reservedTokens: this.reserveTokens
};
} catch (error) {
console.warn('Error checking context limit:', error);
return null;
}
}
/**
* Returns a formatted error message for context limit exceeded
*
* @param {ContextCheckResult} result - Context check result
* @returns {string} Formatted error message
*/
getContextErrorMessage(result: ContextCheckResult): string {
const usagePercent = Math.round((result.currentUsage / result.maxContext) * 100);
return `Context window is nearly full. Current usage: ${result.currentUsage.toLocaleString()}/${result.maxContext.toLocaleString()} tokens (${usagePercent}%). Available space: ${result.availableTokens.toLocaleString()} tokens (${result.reservedTokens} reserved for response).`;
}
/**
* Sets the number of tokens to reserve for response generation
*
* @param {number} tokens - Number of tokens to reserve
*/
setReserveTokens(tokens: number): void {
this.reserveTokens = tokens;
}
}
export const contextService = new ContextService();
@@ -1,3 +1,2 @@
export { chatService } from './chat';
export { contextService } from './context';
export { slotsService } from './slots';
+34 -106
@@ -39,7 +39,6 @@ import type { ExportedConversations } from '$lib/types/database';
* - Conversation branching for exploring different response paths
* - Streaming AI responses with real-time content updates
* - File attachment support (images, PDFs, text files, audio)
* - Context window management with error recovery
* - Partial response saving when generation is interrupted
* - Message editing with automatic response regeneration
*/
@@ -48,11 +47,9 @@ class ChatStore {
activeMessages = $state<DatabaseMessage[]>([]);
conversations = $state<DatabaseConversation[]>([]);
currentResponse = $state('');
errorDialogState = $state<{ type: 'timeout' | 'server'; message: string } | null>(null);
isInitialized = $state(false);
isLoading = $state(false);
maxContextError = $state<{ message: string; estimatedTokens: number; maxContext: number } | null>(
null
);
titleUpdateConfirmationCallback?: (currentTitle: string, newTitle: string) => Promise<boolean>;
constructor() {
@@ -69,8 +66,6 @@ class ChatStore {
try {
await this.loadConversations();
this.maxContextError = null;
this.isInitialized = true;
} catch (error) {
console.error('Failed to initialize chat store:', error);
@@ -99,8 +94,6 @@ class ChatStore {
this.activeConversation = conversation;
this.activeMessages = [];
this.maxContextError = null;
await goto(`#/chat/${conversation.id}`);
return conversation.id;
@@ -133,8 +126,6 @@ class ChatStore {
this.activeMessages = await DatabaseStore.getConversationMessages(convId);
}
this.maxContextError = null;
return true;
} catch (error) {
console.error('Failed to load conversation:', error);
@@ -418,56 +409,6 @@ class ChatStore {
return;
}
if (error.name === 'ContextError') {
console.warn('Context error detected:', error.message);
this.isLoading = false;
this.currentResponse = '';
const messageIndex = this.activeMessages.findIndex(
(m: DatabaseMessage) => m.id === assistantMessage.id
);
if (messageIndex !== -1) {
this.activeMessages.splice(messageIndex, 1);
DatabaseStore.deleteMessage(assistantMessage.id).catch(console.error);
}
// Use structured context info from new exceed_context_size_error format if available
const contextInfo = (
error as Error & {
contextInfo?: { promptTokens: number; maxContext: number; estimatedTokens: number };
}
).contextInfo;
let estimatedTokens = 0;
let maxContext = serverStore.serverProps?.default_generation_settings.n_ctx || 8192;
if (contextInfo) {
// Use precise token counts from server response
estimatedTokens = contextInfo.promptTokens;
maxContext = contextInfo.maxContext;
} else {
// Fallback to estimation for older error format
try {
// Rough estimation: ~4 characters per token
const messageContent = JSON.stringify(messages);
estimatedTokens = Math.ceil(messageContent.length / 4);
} catch {
estimatedTokens = 0;
}
}
this.maxContextError = {
message: error.message,
estimatedTokens,
maxContext
};
if (onError) {
onError(error);
}
return;
}
console.error('Streaming error:', error);
this.isLoading = false;
this.currentResponse = '';
@@ -477,9 +418,19 @@ class ChatStore {
);
if (messageIndex !== -1) {
this.activeMessages[messageIndex].content = `Error: ${error.message}`;
const [failedMessage] = this.activeMessages.splice(messageIndex, 1);
if (failedMessage) {
DatabaseStore.deleteMessage(failedMessage.id).catch((cleanupError) => {
console.error('Failed to remove assistant message after error:', cleanupError);
});
}
}
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
if (onError) {
onError(error);
}
@@ -487,6 +438,14 @@ class ChatStore {
});
}
private showErrorDialog(type: 'timeout' | 'server', message: string): void {
this.errorDialogState = { type, message };
}
dismissErrorDialog(): void {
this.errorDialogState = null;
}
/**
* Checks if an error is an abort error (user cancelled operation)
* @param error - The error to check
@@ -574,6 +533,7 @@ class ChatStore {
return;
}
this.errorDialogState = null;
this.isLoading = true;
this.currentResponse = '';
@@ -603,37 +563,23 @@ class ChatStore {
const conversationContext = this.activeMessages.slice(0, -1);
await this.streamChatCompletion(
conversationContext,
assistantMessage,
undefined,
(error: Error) => {
if (error.name === 'ContextError' && userMessage) {
const userMessageIndex = this.findMessageIndex(userMessage.id);
if (userMessageIndex !== -1) {
this.activeMessages.splice(userMessageIndex, 1);
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
}
}
}
);
await this.streamChatCompletion(conversationContext, assistantMessage);
} catch (error) {
if (this.isAbortError(error)) {
this.isLoading = false;
return;
}
if (error instanceof Error && error.name === 'ContextError' && userMessage) {
const userMessageIndex = this.findMessageIndex(userMessage.id);
if (userMessageIndex !== -1) {
this.activeMessages.splice(userMessageIndex, 1);
DatabaseStore.deleteMessage(userMessage.id).catch(console.error);
}
}
console.error('Failed to send message:', error);
this.isLoading = false;
if (!this.errorDialogState) {
if (error instanceof Error) {
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
this.showErrorDialog(dialogType, error.message);
} else {
this.showErrorDialog('server', 'Unknown error occurred while sending message');
}
}
}
}
@@ -662,24 +608,6 @@ class ChatStore {
this.currentResponse = '';
}
/**
* Clears the max context error state
* Removes any displayed context limit warnings
*/
clearMaxContextError(): void {
this.maxContextError = null;
}
/**
* Sets the max context error state
* @param error - The context error details or null to clear
*/
setMaxContextError(
error: { message: string; estimatedTokens: number; maxContext: number } | null
): void {
this.maxContextError = error;
}
/**
* Saves partial response if generation was interrupted
* Preserves user's partial content and timing data when generation is stopped early
@@ -1250,7 +1178,6 @@ class ChatStore {
this.activeMessages = [];
this.currentResponse = '';
this.isLoading = false;
this.maxContextError = null;
}
/** Refreshes active messages based on currNode after branch navigation */
@@ -1538,6 +1465,7 @@ class ChatStore {
private async generateResponseForMessage(userMessageId: string): Promise<void> {
if (!this.activeConversation) return;
this.errorDialogState = null;
this.isLoading = true;
this.currentResponse = '';
@@ -1584,7 +1512,7 @@ export const activeMessages = () => chatStore.activeMessages;
export const isLoading = () => chatStore.isLoading;
export const currentResponse = () => chatStore.currentResponse;
export const isInitialized = () => chatStore.isInitialized;
export const maxContextError = () => chatStore.maxContextError;
export const errorDialog = () => chatStore.errorDialogState;
export const createConversation = chatStore.createConversation.bind(chatStore);
export const downloadConversation = chatStore.downloadConversation.bind(chatStore);
@@ -1592,9 +1520,9 @@ export const exportAllConversations = chatStore.exportAllConversations.bind(chat
export const importConversations = chatStore.importConversations.bind(chatStore);
export const deleteConversation = chatStore.deleteConversation.bind(chatStore);
export const sendMessage = chatStore.sendMessage.bind(chatStore);
export const dismissErrorDialog = chatStore.dismissErrorDialog.bind(chatStore);
export const gracefulStop = chatStore.gracefulStop.bind(chatStore);
export const clearMaxContextError = chatStore.clearMaxContextError.bind(chatStore);
export const setMaxContextError = chatStore.setMaxContextError.bind(chatStore);
// Branching operations
export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatStore);
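The store changes above swap `maxContextError` for a two-valued `errorDialogState`. The sketch below shows that mapping as a plain class, assuming no Svelte runes; `ErrorDialogModel` and `report` are hypothetical names, and applying the "first error wins" guard to every call is a simplification, since in the diff only the `catch` block in `sendMessage` checks `!this.errorDialogState`.

```typescript
type ErrorDialogState = { type: 'timeout' | 'server'; message: string } | null;

// Hypothetical plain-class stand-in for the `$state`-backed field in ChatStore.
class ErrorDialogModel {
  state: ErrorDialogState = null;

  // Keep the first reported error; later reports are ignored until dismissal.
  report(error: unknown): void {
    if (this.state) return;
    if (error instanceof Error) {
      const type = error.name === 'TimeoutError' ? 'timeout' : 'server';
      this.state = { type, message: error.message };
    } else {
      this.state = { type: 'server', message: 'Unknown error occurred while sending message' };
    }
  }

  dismiss(): void {
    this.state = null;
  }
}

const dialog = new ErrorDialogModel();
const timeout = new Error('Request timed out - the server took too long to respond');
timeout.name = 'TimeoutError';
dialog.report(timeout);
console.log(JSON.stringify(dialog.state));
```

Anything that is not a `TimeoutError` falls into the `'server'` bucket, so error names such as `NetworkError`, `ServerError`, and `HttpError` all share one dialog type.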
@@ -197,7 +197,7 @@ class ServerStore {
errorMessage = 'Server not found - check server address';
isOfflineLikeError = true;
} else if (error.message.includes('ETIMEDOUT')) {
errorMessage = 'Connection timeout - server may be overloaded';
errorMessage = 'Request timed out - the server took too long to respond';
isOfflineLikeError = true;
} else if (error.message.includes('503')) {
errorMessage = 'Server temporarily unavailable - try again shortly';
+1 -7
@@ -1,11 +1,7 @@
<script lang="ts">
import '../app.css';
import { page } from '$app/state';
import {
ChatSidebar,
ConversationTitleUpdateDialog,
MaximumContextAlertDialog
} from '$lib/components/app';
import { ChatSidebar, ConversationTitleUpdateDialog } from '$lib/components/app';
import {
activeMessages,
isLoading,
@@ -145,8 +141,6 @@
<Toaster richColors />
<MaximumContextAlertDialog />
<ConversationTitleUpdateDialog
bind:open={titleUpdateDialogOpen}
currentTitle={titleUpdateCurrentTitle}