Compare commits

...

17 Commits

Author SHA1 Message Date
Georgi Gerganov edfb440a2f server : fix processing of multiple back-to-back mtmd chunks (#21107) 2026-03-28 16:27:36 +02:00
Adrien Gallouët 3d66da1809 ci : gracefully shut down the server (#21110)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-28 14:49:57 +01:00
Woof Dog 82b703f8bc Document custom default webui preferences in server README (#19771) 2026-03-28 14:19:16 +01:00
Aleksander Grygier 51a84efc53 webui: Conversation forking + branching improvements (#21021)
* refactor: Make `DialogConfirmation` extensible with children slot

* feat: Add conversation forking logic

* feat: Conversation forking UI

* feat: Update delete/edit dialogs and logic for forks

* refactor: Improve Chat Sidebar UX and add MCP Servers entry

* refactor: Cleanup

* feat: Update message in place when editing leaf nodes

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* chore: Cleanup

* refactor: Post-review improvements

* chore: update webui build output

* test: Update Storybook test

* chore: update webui build output

* chore: update webui build output
2026-03-28 13:38:15 +01:00
Adrien Gallouët b0f0dd3e51 vendor : update cpp-httplib to 0.40.0 (#21100)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-28 08:59:44 +01:00
Ruben Ortlam 0eb4764182 vulkan: add noncontiguous GLU support (#21081)
* vulkan: add noncontiguous GLU support

* fix compile issue
2026-03-28 08:44:56 +01:00
Piotr Wilkin (ilintar) 1f5d15e665 common/parser: fix reasoning whitespace bugs + extra parser tests (#21085)
* fix whitespace reasoning issues + add reconstruction tests

* Proper fix

* fix Nemotron autoparser test expectations to include newline in marker
2026-03-28 07:29:26 +01:00
Sigbjørn Skjæret c46758d28f cli : add /glob command (#21084)
* add /glob command

* output error when max files reached

* support globbing outside curdir
2026-03-28 02:33:04 +01:00
Ts-sound bf934f28db docker : fix and enable ARM64 image build (#20929)
* CI: fix ARM64 image build error & enable compilation

* Update .github/workflows/docker.yml

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

* CI: revert ggml/src/ggml-cpu/CMakeLists.txt

* Update .github/workflows/docker.yml

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

* CI: update runs-on to ubuntu24.04, and update ARM64 build image ( ubuntu_version: "24.04")

* CI: change cpu.Dockerfile gcc to 14;

* CI : cpu.Dockerfile , update pip install .

* Update .github/workflows/docker.yml

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

---------

Co-authored-by: Aaron Teo <taronaeo@gmail.com>
2026-03-28 01:45:09 +01:00
Adrien Gallouët 5c1a7b8355 server : add custom socket options to disable SO_REUSEPORT (#21056)
* server : add custom socket options to disable SO_REUSEPORT

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --reuse-port

    $ strace -e trace=setsockopt,bind build/bin/llama-server -lv 2 --reuse-port
    setsockopt(3, SOL_TCP, TCP_NODELAY, [1], 4) = 0
    setsockopt(3, SOL_SOCKET, SO_REUSEADDR, [1], 4) = 0
    setsockopt(3, SOL_SOCKET, SO_REUSEPORT, [1], 4) = 0
    bind(3, {sa_family=AF_INET, sin_port=htons(8080), sin_addr=inet_addr("127.0.0.1")}, 16) = 0

    $ strace -e trace=setsockopt,bind build/bin/llama-server -lv 2
    setsockopt(3, SOL_TCP, TCP_NODELAY, [1], 4) = 0
    setsockopt(3, SOL_SOCKET, SO_REUSEADDR, [1], 4) = 0
    bind(3, {sa_family=AF_INET, sin_port=htons(8080), sin_addr=inet_addr("127.0.0.1")}, 16) = 0

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update tools/server/README.md (llama-gen-docs)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix windows

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-03-28 01:12:43 +01:00
Aldehir Rojas 59d840209a common : inhibit lazy grammar sampler while reasoning is active (#20970)
* common : inhibit grammar while reasoning budget is active

* cont : update force_pos in accept

* cont : fix tests

* cont : tweak should apply logic

* cont : return early not using grammar sampler

* Add tests

* cont : prevent backend sampling when reasoning budget enabled

* cont : fix typo

---------

Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
2026-03-27 18:30:40 +01:00
Kusha Gharahi ff934e29bc server: Introduce LLAMA_BUILD_WEBUI build flag to allow disabling the embedded web ui (#20158)
* introduce LLAMA_SERVER_NO_WEBUI

* LLAMA_SERVER_NO_WEBUI → LLAMA_BUILD_WEBUI

* LLAMA_BUILD_WEBUI ON by default not based on LLAMA_STANDALONE

* MIssed this

* Add useWebUi to package.nix
2026-03-27 17:25:55 +01:00
Yiwei Shao ee051c1e4e hexagon: support for IQ4_NL and MXFP4 (#21018)
* ggml-hexagon: add IQ4_NL and MXFP4 HMX matmul support

- Add IQ4_NL quantization type support to Hexagon backend (buffer
  set/get tensor repack, mul_mat, mul_mat_id dispatch)
- Implement HVX IQ4_NL vec_dot kernels (1x1, 2x1, 2x2) with
  LUT-based 4-bit index to int8 kvalue dequantization
- Add MXFP4 HMX dequantization path with E8M0 scale conversion,
  including batch-4 fast path and single-tile fallback
- Unify quantized row size / scale offset logic to handle Q4_0,
  Q8_0, IQ4_NL, and MXFP4 in the DMA fetch path

* ggml-hexagon: fix SKIP_QUANTIZE src1 address mismatch in mixed-quant models

* Fix the pragma indent
2026-03-27 09:22:41 -07:00
Aleksander Grygier e6f6770515 webui: Improve Chat Messages initial scroll + auto-scroll logic + add lazy loading with transitions to content blocks (#20999)
* refactor: Always use agentic content renderer for Assistant Message

* feat: Improve initial scroll + auto-scroll logic + implement fade in action for content blocks

* chore: update webui build output
2026-03-27 17:01:36 +01:00
AN Long 48cda24c11 server: remove the verbose_prompt parameter (#21059)
* server: respect the verbose_prompt parameter

* Revert "server: respect the verbose_prompt parameter"

This reverts commit 8ed885cf37.

* Remove --verbose-prompt parameter from llama-server

* Using set_examples instead of set_excludes
2026-03-27 13:36:13 +02:00
Xuan-Son Nguyen 871f1a2d2f mtmd: add more sanity checks (#21047) 2026-03-27 11:00:52 +01:00
Xuan-Son Nguyen 20197b6fe3 server: add built-in tools backend support (#20898)
* wip: server_tools

* refactor

* displayName -> display_name

* snake_case everywhere

* rm redundant field

* change arg to --tools all

* add readme mention

* llama-gen-docs
2026-03-27 10:07:11 +01:00
73 changed files with 3182 additions and 482 deletions
+7 -4
View File
@@ -1,11 +1,13 @@
ARG UBUNTU_VERSION=22.04
ARG UBUNTU_VERSION=24.04
FROM ubuntu:$UBUNTU_VERSION AS build
ARG TARGETARCH
RUN apt-get update && \
apt-get install -y build-essential git cmake libssl-dev
apt-get install -y gcc-14 g++-14 build-essential git cmake libssl-dev
ENV CC=gcc-14 CXX=g++-14
WORKDIR /app
@@ -55,8 +57,9 @@ RUN apt-get update \
git \
python3 \
python3-pip \
&& pip install --upgrade pip setuptools wheel \
&& pip install -r requirements.txt \
python3-wheel \
&& pip install --break-system-packages --upgrade setuptools \
&& pip install --break-system-packages -r requirements.txt \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+2
View File
@@ -41,6 +41,7 @@
effectiveStdenv ? if useCuda then cudaPackages.backendStdenv else stdenv,
enableStatic ? effectiveStdenv.hostPlatform.isStatic,
precompileMetalShaders ? false,
useWebUi ? true,
}:
let
@@ -164,6 +165,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
cmakeFlags =
[
(cmakeBool "LLAMA_BUILD_SERVER" true)
(cmakeBool "LLAMA_BUILD_WEBUI" useWebUi)
(cmakeBool "BUILD_SHARED_LIBS" (!enableStatic))
(cmakeBool "CMAKE_SKIP_BUILD_RPATH" true)
(cmakeBool "GGML_NATIVE" false)
+11 -13
View File
@@ -36,18 +36,16 @@ jobs:
matrix:
config:
# Multi-stage build
# Note: the arm64 images are failing, which prevents the amd64 images from being built
# https://github.com/ggml-org/llama.cpp/issues/11888
#- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04-s390x" }
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04-s390x" }
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
steps:
- name: Check out the repo
uses: actions/checkout@v6
@@ -58,7 +56,7 @@ jobs:
if: ${{ matrix.config.tag != 's390x' }}
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
with:
image: tonistiigi/binfmt:qemu-v7.0.0-28
image: tonistiigi/binfmt:qemu-v10.2.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
+1
View File
@@ -108,6 +108,7 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_WEBUI "llama: build the embedded Web UI for server" ON)
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
+17 -1
View File
@@ -1079,7 +1079,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.verbose_prompt = true;
}
));
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
{"--display-prompt"},
{"--no-display-prompt"},
@@ -2807,6 +2807,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.port = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
add_opt(common_arg(
{"--reuse-port"},
string_format("allow multiple sockets to bind to the same port (default: %s)", params.reuse_port ? "enabled" : "disabled"),
[](common_params & params) {
params.reuse_port = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_REUSE_PORT"));
add_opt(common_arg(
{"--path"}, "PATH",
string_format("path to serve static files from (default: %s)", params.public_path.c_str()),
@@ -2843,6 +2850,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.webui_mcp_proxy = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
add_opt(common_arg(
{"--tools"}, "TOOL1,TOOL2,...",
"experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
"specify \"all\" to enable all tools\n"
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
[](common_params & params, const std::string & value) {
params.server_tools = parse_csv_row(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
add_opt(common_arg(
{"--webui"},
{"--no-webui"},
+18 -10
View File
@@ -287,7 +287,7 @@ void analyze_reasoning::compare_reasoning_presence() {
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest());
});
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
});
// try the more aggressive parse first, if it fails, fall back to the delimiter one
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
@@ -297,7 +297,7 @@ void analyze_reasoning::compare_reasoning_presence() {
if (result.result.success()) {
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
mode = reasoning_mode::TAG_BASED;
start = trim_whitespace(result.tags["pre"]);
start = trim_leading_whitespace(result.tags["pre"]);
end = trim_trailing_whitespace(result.tags["post"]);
} else if (!result.tags["post"].empty()) {
mode = reasoning_mode::TAG_BASED;
@@ -333,7 +333,7 @@ void analyze_reasoning::compare_thinking_enabled() {
if (left_trimmed.empty() && !diff.right.empty()) {
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
if (start.empty()) {
start = right_trimmed;
start = trim_leading_whitespace(diff.right);
mode = reasoning_mode::TAG_BASED;
}
}
@@ -344,7 +344,7 @@ void analyze_reasoning::compare_thinking_enabled() {
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
start = seg[seg.size() - 2].value;
}
end = left_trimmed;
end = trim_trailing_whitespace(diff.left);
mode = reasoning_mode::TAG_BASED;
}
}
@@ -363,15 +363,23 @@ void analyze_reasoning::compare_thinking_enabled() {
size_t len = std::min(base.size(), anchor_len);
std::string anchor = base.substr(base.size() - len);
auto pos = extended.rfind(anchor);
if (pos == std::string::npos || pos + len >= extended.size()) continue;
if (pos == std::string::npos || pos + len >= extended.size()) {
continue;
}
std::string extra = trim_whitespace(extended.substr(pos + len));
if (extra.empty()) continue;
if (extra.empty()) {
continue;
}
auto seg = prune_whitespace_segments(segmentize_markers(extra));
if (seg.size() == 2 && seg[0].type == segment_type::MARKER && seg[1].type == segment_type::MARKER) {
if (start.empty()) start = seg[0].value;
if (end.empty()) end = seg[1].value;
if (start.empty()) {
start = seg[0].value;
}
if (end.empty()) {
end = seg[1].value;
}
mode = reasoning_mode::TAG_BASED;
break;
}
@@ -423,7 +431,7 @@ void analyze_reasoning::compare_reasoning_scope() {
LOG_DBG(ANSI_ORANGE "%s: Detected TOOLS_ONLY reasoning mode\n" ANSI_RESET, __func__);
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space()));
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space()));
});
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
if (result.result.success()) {
@@ -516,7 +524,7 @@ analyze_content::analyze_content(const common_chat_template & tmpl, const analyz
// Take the more promising diff
std::string pure_content = rdiff.length() > diff_tools.left.length() ? rdiff : diff_tools.left;
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
return p.tag("pre", p.marker()) + p.space() + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
return p.tag("pre", p.marker() + p.space()) + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
});
auto result = parser_wrapped.parse_anywhere_and_extract(pure_content);
start = result.tags["pre"];
+32
View File
@@ -656,6 +656,38 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
return true;
}
// simple glob: * matches non-/ chars, ** matches anything including /
static inline bool glob_match(const char * pattern, const char * str) {
if (*pattern == '\0') {
return *str == '\0';
}
if (pattern[0] == '*' && pattern[1] == '*') {
const char * p = pattern + 2;
if (*p == '/') p++;
if (glob_match(p, str)) return true;
if (*str != '\0') return glob_match(pattern, str + 1);
return false;
}
if (*pattern == '*') {
const char * p = pattern + 1;
for (; *str != '\0' && *str != '/'; str++) {
if (glob_match(p, str)) return true;
}
return glob_match(p, str);
}
if (*pattern == '?' && *str != '\0' && *str != '/') {
return glob_match(pattern + 1, str + 1);
}
if (*pattern == *str) {
return glob_match(pattern + 1, str + 1);
}
return false;
}
bool glob_match(const std::string & pattern, const std::string & str) {
return glob_match(pattern.c_str(), str.c_str());
}
//
// Filesystem utils
//
+6
View File
@@ -573,6 +573,7 @@ struct common_params {
// server params
int32_t port = 8080; // server listens on this network port
bool reuse_port = false; // allow multiple sockets to bind to the same port
int32_t timeout_read = 600; // http read timeout in seconds
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
@@ -613,6 +614,9 @@ struct common_params {
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
// enable built-in tools
std::vector<std::string> server_tools;
// router server configs
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
@@ -790,6 +794,8 @@ std::string string_from(const std::vector<int> & values);
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
bool glob_match(const std::string & pattern, const std::string & str);
//
// Filesystem utils
//
+12 -11
View File
@@ -115,9 +115,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
break;
}
case REASONING_BUDGET_FORCING:
// force_pos is advanced in apply(), not here.
// This ensures the first forced token isn't skipped when the sampler
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
break;
case REASONING_BUDGET_DONE:
break;
@@ -144,14 +146,6 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
cur_p->data[i].logit = -INFINITY;
}
}
// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
@@ -261,3 +255,10 @@ struct llama_sampler * common_reasoning_budget_init(
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
if (!smpl) {
return REASONING_BUDGET_IDLE;
}
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
}
+2
View File
@@ -51,3 +51,5 @@ struct llama_sampler * common_reasoning_budget_init(
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state);
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
+46 -10
View File
@@ -7,6 +7,7 @@
#include <algorithm>
#include <cctype>
#include <climits>
#include <cmath>
#include <cstring>
#include <unordered_map>
@@ -109,6 +110,7 @@ struct common_sampler {
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * rbudget;
struct llama_sampler * chain;
ring_buffer<llama_token> prev;
@@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
lparams.no_perf = params.no_perf;
llama_sampler * grmr = nullptr;
llama_sampler * rbudget = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
std::vector<llama_sampler *> samplers;
@@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
if (grmr) {
if (grmr && !params.grammar_lazy) {
try {
for (const auto & token : prefill_tokens) {
llama_sampler_accept(grmr, token);
@@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
// reasoning budget sampler
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
rbudget = common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
prefill_tokens));
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
prefill_tokens);
}
if (params.has_logit_bias()) {
@@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .rbudget = */ rbudget,
/* .chain = */ chain,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
@@ -398,11 +402,27 @@ void common_sampler_free(struct common_sampler * gsmpl) {
}
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->rbudget);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
}
static bool grammar_should_apply(struct common_sampler * gsmpl) {
if (!gsmpl->grmr) {
return false;
}
if (!gsmpl->rbudget) {
return true;
}
if (gsmpl->params.grammar_lazy) {
// if grammar is lazy, only apply when reasoning budget is not active
const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
}
return true;
}
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (!gsmpl) {
return;
@@ -410,6 +430,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
const auto tm = gsmpl->tm();
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
llama_sampler_accept(gsmpl->rbudget, token);
if (gsmpl->grmr && accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
@@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .rbudget = */ llama_sampler_clone(gsmpl->rbudget),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
@@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_token id = LLAMA_TOKEN_NULL;
auto & grmr = gsmpl->grmr;
auto & rbudget = gsmpl->rbudget;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
@@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
if (id != LLAMA_TOKEN_NULL) {
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
// TODO: simplify
gsmpl->cur.resize(1);
@@ -524,7 +552,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
gsmpl->set_logits(ctx, idx);
if (grammar_first) {
// apply reasoning budget first
llama_sampler_apply(rbudget, &cur_p);
if (grammar_first && grammar_should_apply(gsmpl)) {
llama_sampler_apply(grmr, &cur_p);
}
@@ -532,7 +563,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
if (grammar_first || !grammar_should_apply(gsmpl)) {
return id;
}
@@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(rbudget, &cur_p);
if (grammar_should_apply(gsmpl)) {
llama_sampler_apply(grmr, &cur_p);
}
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
+36 -1
View File
@@ -1406,6 +1406,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
repack_q8_0_q8x4x2(tensor, data, size);
break;
case GGML_TYPE_IQ4_NL:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
// IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16])
repack_q4_0_q4x4x2(tensor, data, size);
break;
case GGML_TYPE_MXFP4:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
@@ -1442,6 +1449,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
repack_q8x4x2_q8_0(data, tensor, size);
break;
case GGML_TYPE_IQ4_NL:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
repack_q4x4x2_q4_0(data, tensor, size);
break;
case GGML_TYPE_MXFP4:
GGML_ASSERT(offset == 0);
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
@@ -1819,6 +1832,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
if (src0->ne[0] % 32) {
return false;
@@ -1868,6 +1882,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
if ((src0->ne[0] % 32)) {
return false;
@@ -2596,8 +2611,26 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
delete backend;
}
// Map weight type to its activation quantization family.
// Types in the same family produce identical Q8 formats in VTCM and can
// safely share quantized activation data via SKIP_QUANTIZE.
// When adding a new quantized type, assign it the correct family here.
static inline int act_quant_family(enum ggml_type wtype) {
switch (wtype) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
return 1; // Q8x4x2
default:
return 0; // unknown / not quantized
}
}
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
return (op0 && op0->src[1] == op1->src[1] &&
act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) &&
act_quant_family(op0->src[0]->type) != 0);
}
static inline bool is_compute_op(ggml_tensor *node)
@@ -3364,6 +3397,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
"please update hexagon_type to match ggml_type");
static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
"please update hexagon_type to match ggml_type");
static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL,
"please update hexagon_type to match ggml_type");
const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
+193 -16
View File
@@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
};
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0,
};
static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
-127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0,
@@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
#define HMX_X4X2_SCALES_PER_BLK 8
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL)
#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4)
static inline void swap_ptr(void **p1, void **p2) {
void *t = *p1;
@@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
case HTP_TYPE_Q8_0:
return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
case HTP_TYPE_MXFP4:
return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb
default:
return 0;
}
@@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
}
// --- MXFP4 E8M0 scale conversion and dequantization ---
//
// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack.
// Scalar loads from the stack array execute on the scalar pipeline, in parallel
// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop.
// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10
// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15.
typedef struct {
__fp16 v[8] __attribute__((aligned(16)));
} mxfp4_scales_t;
static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) {
mxfp4_scales_t s;
HVX_Vector v = hvx_vmemu(e8m0_8);
HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v));
vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112));
vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero());
vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30));
vh = Q6_Vh_vasl_VhR(vh, 10);
hvx_vec_store_u(s.v, 16, vh);
return s;
}
static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) {
return hvx_vec_splat_f16(scales.v[idx]);
}
// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16.
static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32,
bool upper_nibbles,
int sub_blk,
const HVX_Vector vlut_cvt,
mxfp4_scales_t scales) {
HVX_Vector vq = hvx_vmemu(packed_32);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_hf = Q6_V_lo_W(vp);
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc));
}
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
bool upper_nibbles,
int sub_blk_base,
const HVX_Vector vlut_cvt,
mxfp4_scales_t scales,
HVX_Vector out[4]) {
HVX_Vector vq = hvx_vmemu(packed_128);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
v_quants = Q6_Vb_vshuff_Vb(v_quants);
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
HVX_Vector v_lo = Q6_V_lo_W(vp);
HVX_Vector v_hi = Q6_V_hi_W(vp);
HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0),
mxfp4_extract_splat(scales, sub_blk_base + 1));
HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2),
mxfp4_extract_splat(scales, sub_blk_base + 3));
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
out[0] = v_lo;
out[1] = Q6_V_vror_VR(v_lo, 64);
out[2] = v_hi;
out[3] = Q6_V_vror_VR(v_hi, 64);
}
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
// Output: vtcm_dst in tile-major FP16 layout.
@@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
int start_tile, int end_tile) {
const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
const int qrow_size = is_q4 ? (k_block / 2) : k_block;
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL)
? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut);
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
hvx_vmem(q4_0_to_fp16_lut);
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
@@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
int ct = t / n_k_tiles; // column tile index
int kt = t % n_k_tiles; // K tile index
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
((t + 3) / n_k_tiles == ct)) {
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
@@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
continue;
}
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
__fp16 * tile_bases[4];
for (int g = 0; g < 4; g++) {
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
}
HVX_Vector v_off = v_scat_base;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t * r0 = vtcm_src + row0 * row_stride;
const uint8_t * r1 = vtcm_src + row1 * row_stride;
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
HVX_Vector v0[4], v1[4];
dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
if (row1 < n_cols) {
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
} else {
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
}
for (int g = 0; g < 4; g++) {
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
}
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
for (int g = 0; g < 4; g++) {
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
}
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
for (int g = 0; g < 4; g++) {
(void) *(volatile HVX_Vector *) (tile_bases[g]);
}
t += 4;
continue;
}
// --- Single-tile fallback ---
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
if (is_q4) {
if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
bool upper = (sub_blk >= 4);
@@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
(void) *(volatile HVX_Vector *)(tile_base);
} else if (weight_type == HTP_TYPE_MXFP4) {
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
bool upper = (sub_blk >= 4);
int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
HVX_Vector v_off = v_scat_base;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t * r0 = vtcm_src + row0 * row_stride;
const uint8_t * r1 = vtcm_src + row1 * row_stride;
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
HVX_Vector v1;
if (row1 < n_cols) {
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
} else {
v1 = Q6_V_vzero();
}
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}
(void) *(volatile HVX_Vector *) (tile_base);
} else {
// Q8_0
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
@@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
{
qweight_fetch_task_state_t s;
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
const int blk_start = kk / QK_Q4_0x4x2;
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
const int full_qrow = is_q4 ? (k / 2) : k;
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
const int scale_blk_size =
(weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
s.dst = vtcm_scratch0;
s.src = w + nc * row_stride;
s.n_rows = n_blk_sz;
s.src_stride = row_stride;
s.dst_stride = sub_row_stride;
s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2);
s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2);
s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE;
s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE;
s.quant_off =
(weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
s.quant_width =
(weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
s.scale_off = full_qrow + blk_start * scale_blk_size;
s.scale_width = nb_sub * scale_blk_size;
// 2D DMA: quants sub-range
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
+6
View File
@@ -31,6 +31,12 @@ struct htp_context {
uint32_t opmask;
// Cached src1 spad position from the last quantize pass.
// When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM
// at this address; the matmul must read from here instead of recomputing
// the offset (which depends on the current op's src0 size).
uint8_t * prev_src1_spad;
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
#ifdef HTP_HAS_HMX
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
+4 -6
View File
@@ -1114,14 +1114,12 @@ static void proc_hmx_matmul_req(struct htp_context * ctx,
return;
}
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
// Other types (e.g. MXFP4) fall back to HVX.
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
// Other types fall back to HVX.
{
uint32_t wtype = req->src0.type;
if (wtype != HTP_TYPE_F16 &&
wtype != HTP_TYPE_Q4_0 &&
wtype != HTP_TYPE_Q8_0 &&
wtype != HTP_TYPE_IQ4_NL) {
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL &&
wtype != HTP_TYPE_MXFP4) {
proc_matmul_req(ctx, req, bufs, n_bufs);
return;
}
+380
View File
@@ -60,6 +60,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
};
// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue
// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = {
0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0,
0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -68,6 +78,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
HVX_Vector v2_3 = vptr[1]; // ...
HVX_Vector v4_5 = vptr[2]; // ...
HVX_Vector v6_7 = vptr[3]; // ...
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
return r;
}
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
const uint32_t qk = QK_Q4_0x4x2; // 256
const uint32_t nb = n / qk;
const uint32_t nloe = n % qk;
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
HVX_Vector_x8 r;
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nb; i++) {
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
}
if (nloe) {
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
}
return r;
}
// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
static inline size_t q8x4x2_row_size(uint32_t ne) {
@@ -921,6 +998,293 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
// ======== IQ4_NL x Q8_0 vec_dot kernels ========
// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue).
// Scale format is identical to Q4_0 (fp16 scales).
static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n,
float * restrict s0,
const void * restrict vx0,
const void * restrict vy0) {
assert(n % 32 == 0);
assert((unsigned long) vx0 % 128 == 0);
assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t x_qblk_size = qk / 2; // int4
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
HVX_Vector r0_sum = Q6_V_vzero();
const uint32_t nb = n / qk;
const uint32_t nloe = n % qk;
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
}
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
hvx_vec_store_u(s0, 4, r0_sum);
}
static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n,
float * restrict s0,
const void * restrict vx0,
const void * restrict vx1,
const void * restrict vy0) {
assert(n % 32 == 0);
assert((unsigned long) vx0 % 128 == 0);
assert((unsigned long) vx1 % 128 == 0);
assert((unsigned long) vy0 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t x_qblk_size = qk / 2; // int4
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
HVX_Vector r0_sum = Q6_V_vzero();
HVX_Vector r1_sum = Q6_V_vzero();
const uint32_t nb = n / qk;
const uint32_t nloe = n % qk;
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
if (nloe) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
}
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
hvx_vec_store_u(s0, 8, rsum);
}
static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n,
float * restrict s0,
float * restrict s1,
const void * restrict vx0,
const void * restrict vx1,
const void * restrict vy0,
const void * restrict vy1) {
assert(n % 32 == 0);
assert((unsigned long) vx0 % 128 == 0);
assert((unsigned long) vx1 % 128 == 0);
assert((unsigned long) vy0 % 128 == 0);
assert((unsigned long) vy1 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t x_qblk_size = qk / 2; // int4
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;
HVX_Vector r0_c0_sum = Q6_V_vzero();
HVX_Vector r0_c1_sum = Q6_V_vzero();
HVX_Vector r1_c0_sum = Q6_V_vzero();
HVX_Vector r1_c1_sum = Q6_V_vzero();
const uint32_t nb = n / qk;
const uint32_t nloe = n % qk;
uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
if (nloe) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);
}
static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
assert(n % 32 == 0); // min sub-block size
assert((unsigned long) vx0 % 128 == 0);
@@ -2393,6 +2757,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t
mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
return 0;
case HTP_TYPE_IQ4_NL:
mmctx->type = "iq4nlx4x2-f32";
mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1;
mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1;
mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2;
return 0;
case HTP_TYPE_MXFP4:
mmctx->type = "mxfp4x4x2-f32";
mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
@@ -2556,6 +2926,13 @@ int op_matmul(struct htp_ops_context * octx) {
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
// Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it
octx->ctx->prev_src1_spad = octx->src1_spad.data;
} else {
// SKIP_QUANTIZE: Q8 data lives at the address written by the previous
// quantize pass. The current op may have a different src0 size (e.g.
// IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong.
octx->src1_spad.data = octx->ctx->prev_src1_spad;
}
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
@@ -2659,6 +3036,9 @@ int op_matmul_id(struct htp_ops_context * octx) {
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
octx->ctx->prev_src1_spad = octx->src1_spad.data;
} else {
octx->src1_spad.data = octx->ctx->prev_src1_spad;
}
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+27 -12
View File
@@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants {
uint32_t mode; // 0: default, 1: swapped, 2: split
float alpha; // for swiglu_oai
float limit;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t ne01;
uint32_t ne02;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t ne11;
uint32_t ne12;
};
struct vk_op_unary_push_constants {
@@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else {
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
}
vk::DeviceCreateInfo device_create_info;
vk::DeviceCreateInfo device_create_info{};
std::vector<const char *> device_extensions;
vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
@@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
#endif
device->name = GGML_VK_NAME + std::to_string(idx);
device_create_info = {
vk::DeviceCreateFlags(),
device_queue_create_infos,
{},
device_extensions
};
device_create_info
.setFlags(vk::DeviceCreateFlags())
.setQueueCreateInfos(device_queue_create_infos)
.setPEnabledExtensionNames(device_extensions);
device_create_info.setPNext(&device_features2);
device->device = device->physical_device.createDevice(device_create_info);
@@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
const float alpha = op_params_f[2];
const float limit = op_params_f[3];
GGML_ASSERT(ggml_is_contiguous(src0));
if (!split) {
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
} else {
@@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
(uint32_t)dst->ne[0],
mode,
alpha,
limit
limit,
(uint32_t)(src0->nb[1] / src0->nb[0]),
(uint32_t)(src0->nb[2] / src0->nb[0]),
(uint32_t)(src0->nb[3] / src0->nb[0]),
(uint32_t)src0->ne[1],
(uint32_t)src0->ne[2],
(uint32_t)(dst->nb[1] / dst->nb[0]),
(uint32_t)(dst->nb[2] / dst->nb[0]),
(uint32_t)(dst->nb[3] / dst->nb[0]),
(uint32_t)dst->ne[1],
(uint32_t)dst->ne[2]
});
}
@@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(op->src[0]->type == op->type);
default:
@@ -16,4 +16,14 @@ layout (push_constant) uniform parameter
uint mode;
float alpha;
float limit;
uint nb01;
uint nb02;
uint nb03;
uint ne01;
uint ne02;
uint nb11;
uint nb12;
uint nb13;
uint ne11;
uint ne12;
} p;
@@ -8,22 +8,32 @@ void main() {
const uint row = i / p.ne20;
const uint col = i - row * p.ne20;
const uint i3 = row / (p.ne01 * p.ne02);
const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01;
const uint i1 = row % p.ne01;
const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col;
const uint dst_i3 = row / (p.ne11 * p.ne12);
const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11;
const uint dst_i1 = row % p.ne11;
const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col;
if (p.mode == 0) {
// Default
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
const uint idx = src_idx;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
} else if (p.mode == 1) {
// Swapped
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
const uint idx = src_idx;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
} else {
// Split
const uint idx = row * p.ne00 + col;
const uint idx = src_idx;
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
}
}
+1 -1
View File
@@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.39.0"
HTTPLIB_VERSION = "refs/tags/v0.40.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
+1 -1
View File
@@ -1330,7 +1330,7 @@ static void test_nemotron_reasoning_detection(testing & t) {
analysis.analyze_template(tmpl);
// Check reasoning markers
t.assert_equal("reasoning_start should be '<think>'", "<think>", analysis.reasoning.start);
t.assert_equal("reasoning_start should be '<think>\\n'", "<think>\n", analysis.reasoning.start);
t.assert_equal("reasoning_end should be '</think>'", "</think>", analysis.reasoning.end);
// Check reasoning mode detection
+332 -81
View File
@@ -805,7 +805,8 @@ struct peg_test_case {
common_chat_templates_inputs params;
std::string input;
common_chat_msg expect;
bool is_partial = false;
bool is_partial = false;
bool expect_reconstruction = false;
};
struct make_peg_parser {
@@ -828,6 +829,12 @@ struct make_peg_parser {
}
};
// Global template filter for --template flag
static std::string g_template_filter;
// When true, run reconstruction test on every non-partial test and report results
static bool g_force_reconstruction_test = false;
static void test_peg_parser(common_chat_templates * tmpls,
const std::function<void(peg_test_case &)> & init,
bool detailed_debug) {
@@ -936,75 +943,158 @@ static void test_peg_parser(common_chat_templates * tmpls,
throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar);
}
// Find the earliest trigger position to determine the constrained portion
auto earliest_trigger_pos = std::string::npos;
for (const auto & trigger : parser.params_.grammar_triggers) {
size_t pos = std::string::npos;
std::smatch match;
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
pos = tc.input.find(word);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
const auto & pattern = std::regex(trigger.value);
if (std::regex_search(tc.input, match, pattern)) {
pos = match.position(pattern.mark_count());
}
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
const auto & pattern = trigger.value;
if (std::regex_match(tc.input, match, std::regex(pattern))) {
auto mpos = std::string::npos;
for (size_t i = 1; i < match.size(); ++i) {
if (match[i].length() > 0) {
mpos = match.position(i);
break;
}
}
if (mpos == std::string::npos) {
mpos = match.position(0);
}
pos = mpos;
}
break;
}
default:
throw std::runtime_error("Unknown trigger type");
}
if (pos != std::string::npos) {
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
earliest_trigger_pos = pos;
// In production, grammar triggers match against the full generated text
// including the generation prompt. All positions are in full_input coordinates.
const auto & gen_prompt = parser.params_.generation_prompt;
std::string full_input = gen_prompt + tc.input;
// Determine whether the reasoning-budget sampler path applies: tool-call grammar
// with all WORD triggers and thinking tags present. In production, the reasoning
// budget sampler inhibits grammar application while inside thinking blocks —
// triggers inside <think>...</think> are suppressed.
bool use_reasoning_budget_path = false;
if (parser.params_.grammar_lazy && !parser.params_.thinking_end_tag.empty()) {
use_reasoning_budget_path = true;
for (const auto & trigger : parser.params_.grammar_triggers) {
if (trigger.type != COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
use_reasoning_budget_path = false;
break;
}
}
}
// Determine the constrained portion of input to test against grammar
std::string constrained = tc.input;
// Find the earliest trigger position to determine the constrained portion
auto earliest_trigger_pos = std::string::npos;
if (use_reasoning_budget_path) {
// Reasoning-budget path: simulate thinking-aware trigger detection.
// Walk through full_input tracking thinking state; only match triggers
// when outside thinking blocks.
const auto & think_start = parser.params_.thinking_start_tag;
const auto & think_end = parser.params_.thinking_end_tag;
bool in_thinking = false;
for (size_t i = 0; i < full_input.size(); ++i) {
if (!in_thinking && !think_start.empty()
&& full_input.compare(i, think_start.size(), think_start) == 0) {
in_thinking = true;
i += think_start.size() - 1;
continue;
}
if (in_thinking && full_input.compare(i, think_end.size(), think_end) == 0) {
in_thinking = false;
i += think_end.size() - 1;
continue;
}
if (in_thinking) {
continue;
}
// Outside thinking — check if any trigger word starts here
for (const auto & trigger : parser.params_.grammar_triggers) {
if (full_input.compare(i, trigger.value.size(), trigger.value) == 0) {
if (earliest_trigger_pos == std::string::npos || i < earliest_trigger_pos) {
earliest_trigger_pos = i;
}
}
}
if (earliest_trigger_pos != std::string::npos) {
break; // found the earliest
}
}
// If the reasoning-budget path found no trigger outside thinking but the test
// expects tool calls, this template nests tool calls inside thinking
// blocks (e.g. Kimi). Fall back to the legacy path for this case.
if (earliest_trigger_pos == std::string::npos && !tc.expect.tool_calls.empty()) {
use_reasoning_budget_path = false;
}
}
if (!use_reasoning_budget_path) {
// Legacy path: find triggers without thinking-awareness
for (const auto & trigger : parser.params_.grammar_triggers) {
size_t pos = std::string::npos;
std::smatch match;
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
pos = full_input.find(word);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
const auto & compiled = std::regex(trigger.value);
if (std::regex_search(full_input, match, compiled)) {
pos = match.position(compiled.mark_count());
}
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
// In production, PATTERN_FULL triggers are checked against
// the text generated so far, growing token by token. Simulate
// by trying every prefix of full_input.
const auto & compiled = std::regex(trigger.value);
for (size_t end = gen_prompt.size(); end <= full_input.size(); ++end) {
std::string prefix = full_input.substr(0, end);
if (std::regex_match(prefix, match, compiled)) {
pos = std::string::npos;
for (size_t gi = 1; gi < match.size(); ++gi) {
if (match[gi].length() > 0) {
pos = match.position(gi);
break;
}
}
if (pos == std::string::npos) {
pos = match.position(0);
}
break;
}
}
break;
}
default:
throw std::runtime_error("Unknown trigger type");
}
if (pos != std::string::npos) {
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
earliest_trigger_pos = pos;
}
}
}
}
// If the test expects tool calls and the grammar is lazy, the trigger must fire.
// Otherwise the grammar would never activate in production and tool calls wouldn't
// be constrained. A silent skip here would hide broken triggers.
if (parser.params_.grammar_lazy && !tc.expect.tool_calls.empty() && !tc.is_partial
&& earliest_trigger_pos == std::string::npos) {
std::string trigger_desc;
for (const auto & trigger : parser.params_.grammar_triggers) {
trigger_desc += "\n [type=" + std::to_string(trigger.type) + "] " + trigger.value;
}
throw std::runtime_error(
"Grammar trigger did not fire, but test expects tool calls (lazy grammar).\n"
">>> Input: " + full_input + "\n"
">>> Triggers (" + std::to_string(parser.params_.grammar_triggers.size()) + "):" + trigger_desc);
}
// Determine the constrained portion of input to test against grammar.
// If the trigger position falls inside the generation prompt, the grammar
// sampler was already active before model output began — constrain from the
// start of the model output (i.e. tc.input).
std::string constrained = full_input;
bool grammar_triggered = false;
if (earliest_trigger_pos != std::string::npos) {
constrained = tc.input.substr(earliest_trigger_pos);
auto constrain_from = std::max(earliest_trigger_pos, gen_prompt.size());
constrained = full_input.substr(constrain_from);
grammar_triggered = true;
} else if (!parser.params_.grammar_lazy) {
// For non-lazy grammars, the entire input should match
grammar_triggered = true;
}
// For non-lazy grammars, prepend reasoning prefill to grammar input, just like
// PEG parsing does. The grammar includes the full reasoning pattern (e.g. optional
// <think>...</think>), but the model output may start mid-reasoning if the template
// already placed the opening tag in the prompt.
// For lazy grammars, the grammar only activates from the trigger position, so the
// reasoning prefill is irrelevant — reasoning is handled by the PEG parser.
if (!parser.params_.generation_prompt.empty() && earliest_trigger_pos == std::string::npos) {
constrained = parser.params_.generation_prompt + constrained;
}
// Test the constrained portion against the grammar
if (grammar_triggered && !tc.is_partial) {
auto result = match_string_detailed(constrained, grammar.get());
@@ -1036,10 +1126,57 @@ static void test_peg_parser(common_chat_templates * tmpls,
}
}
}
}
// Global template filter for --template flag
static std::string g_template_filter;
// Reconstruction test: verify that appending the parsed message to the original
// messages and re-rendering the template (without generation prompt) reproduces
// the original prompt + input exactly, or as a proper prefix (the template may
// append end-of-turn tokens after the assistant message).
if ((tc.expect_reconstruction || g_force_reconstruction_test) && !tc.is_partial) {
// Start from tc.expect but copy tool call arguments from the actual parser
// output, which preserves original JSON formatting (e.g. {"arg1":1} vs {"arg1": 1}).
auto reconstruction_msg = tc.expect;
auto parsed_msg = parser.parse(tc.input, false);
for (size_t i = 0; i < reconstruction_msg.tool_calls.size() && i < parsed_msg.tool_calls.size(); i++) {
reconstruction_msg.tool_calls[i].arguments = parsed_msg.tool_calls[i].arguments;
}
common_chat_templates_inputs reconstruction_inputs = tc.params;
reconstruction_inputs.messages.push_back(reconstruction_msg);
reconstruction_inputs.add_generation_prompt = false;
auto reconstruction_params = common_chat_templates_apply(tmpls, reconstruction_inputs);
std::string expected_text = parser.params_.prompt + tc.input;
bool match = reconstruction_params.prompt == expected_text ||
(reconstruction_params.prompt.size() > expected_text.size() &&
reconstruction_params.prompt.compare(0, expected_text.size(), expected_text) == 0);
if (!match && g_force_reconstruction_test && !tc.expect_reconstruction) {
// In forced mode, report mismatch but don't fail
// Find the first difference position
size_t diff_pos = 0;
size_t min_len = std::min(expected_text.size(), reconstruction_params.prompt.size());
while (diff_pos < min_len && expected_text[diff_pos] == reconstruction_params.prompt[diff_pos]) {
diff_pos++;
}
size_t ctx_start = diff_pos > 60 ? diff_pos - 60 : 0;
size_t ctx_end_e = std::min(expected_text.size(), diff_pos + 40);
size_t ctx_end_r = std::min(reconstruction_params.prompt.size(), diff_pos + 40);
LOG_ERR("\x1b[31m[RECONSTRUCTION FAIL]\x1b[0m "
"first diff at byte %zu (expected len=%zu, reconstructed len=%zu)\n"
" expected: ...%s...\n"
" reconstructed: ...%s...\n",
diff_pos, expected_text.size(), reconstruction_params.prompt.size(),
expected_text.substr(ctx_start, ctx_end_e - ctx_start).c_str(),
reconstruction_params.prompt.substr(ctx_start, ctx_end_r - ctx_start).c_str());
} else if (!match) {
std::string error_msg =
"Reconstruction mismatch:\n\n"
">>> Expected (prompt + input):\n" + expected_text +
"\n\n>>> Reconstructed:\n" + reconstruction_params.prompt;
throw std::runtime_error(error_msg);
} else if (g_force_reconstruction_test) {
LOG_INF("\x1b[32m[RECONSTRUCTION OK]\x1b[0m\n");
}
}
}
// Fluent builder for PEG parser tests
class peg_test_builder;
@@ -1099,6 +1236,11 @@ class peg_test_builder {
return *this;
}
peg_test_builder & expect_reconstruction(bool val = true) {
tc_.expect_reconstruction = val;
return *this;
}
// Expect setters
peg_test_builder & expect(const common_chat_msg & msg) {
tc_.expect = msg;
@@ -1272,16 +1414,18 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Ministral-3-14B-Reasoning-2512
auto tst = peg_tester("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
.expect_content("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
.expect_reconstruction()
.run();
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.expect(message_assist_thoughts)
.expect_reconstruction()
.run();
tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})")
@@ -1311,6 +1455,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.expect_reconstruction()
.run();
tst.test(
@@ -1323,6 +1468,20 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_reasoning("I need to output the invoice details in JSON")
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
.run();
// fake tool call marker in reasoning
tst.test(
"[THINK]Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more[/THINK]"
R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.tools({ special_function_tool })
.expect_reasoning("Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more")
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
})
.expect_reconstruction()
.run();
}
{
@@ -1425,6 +1584,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_reasoning("I need to output the invoice details in JSON")
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
.run();
// tool call segment in reasoning
tst.test(
"Let's call a tool: <tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Not the real call!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call></think>\n"
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>"
)
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({
python_tool
})
.expect_reasoning("Let's call a tool: <tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Not the real call!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>")
.expect_tool_calls({
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
})
.run();
}
{
@@ -1481,9 +1684,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Google Gemma 2 2B - does not support tool calling
auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja");
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run();
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).expect_reconstruction().run();
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).run();
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run();
}
{
@@ -1526,7 +1729,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Test simple content-only template
auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
}
{
// IBM Granite (reasoning and tool calling model)
@@ -1638,7 +1841,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Qwen3-Coder (tool calling with XML-style format)
auto tst = peg_tester("models/templates/Qwen3-Coder.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test(
"<tool_call>\n"
@@ -1650,6 +1853,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
"</tool_call>")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
tst.test(
@@ -1678,6 +1882,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.expect_reconstruction()
.run();
// Test with code content (multiline)
@@ -1698,6 +1903,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_tool_calls({
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
})
.expect_reconstruction()
.run();
// Test with code content (asian unicode chars)
@@ -1715,6 +1921,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_tool_calls({
{ "python", "{\"code\": \"\"}", {} },
})
.expect_reconstruction()
.run();
// Test with HTML tag content
@@ -1736,6 +1943,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_tool_calls({
{ "html", "{\"markup\": \"<html>\\n <head>\\n <title>Hello!</title>\\n </head>\\n</html>\"}", {} },
})
.expect_reconstruction()
.run();
// Test with TODO list (array of objects)
@@ -1753,6 +1961,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_tool_calls({
{ "todo_list", "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", {} },
})
.expect_reconstruction()
.run();
// Test flexible optional argument ordering (2 required + 4 optional, reversed optional order)
@@ -1769,6 +1978,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_tool_calls({
{ "tool_2req_4opt", R"({"req1": "hello", "req2": 42, "opt4": 100, "opt2": 200})", {} },
})
.expect_reconstruction()
.run();
// Test flexible optional argument ordering (2 required + 5 optional, reversed optional order)
@@ -1786,6 +1996,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_tool_calls({
{ "tool_2req_5opt", R"({"req1": "world", "req2": 7, "opt5": "last", "opt3": "middle", "opt1": "first"})", {} },
})
.expect_reconstruction()
.run();
// Test flexible optional argument ordering (2 required + 5 optional, all 5 in shuffled order)
@@ -1805,6 +2016,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_tool_calls({
{ "tool_2req_5opt", R"({"req1": "test", "req2": 99, "opt3": "c", "opt1": "a", "opt5": "e", "opt4": 4, "opt2": 2})", {} },
})
.expect_reconstruction()
.run();
}
{
@@ -1885,6 +2097,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test("Hello, world!\nWhat's up?")
.enable_thinking(false)
.expect(message_assist)
.expect_reconstruction()
.run();
// Reasoning with content (forced-open mode - input starts after <think>)
@@ -1892,6 +2105,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.expect(message_assist_thoughts)
.expect_reconstruction()
.run();
// Tool call without reasoning
@@ -1902,6 +2116,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.enable_thinking(false)
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
// Tool call with reasoning (forced-open mode)
@@ -1914,6 +2129,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.expect_reconstruction()
.run();
tst.test(
@@ -1933,6 +2149,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.expect_reconstruction()
.run();
// #20650: tool with no required args, model emits <tool_call>name</tool_call> with no arg tags.
@@ -1950,6 +2167,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.tools({ no_args_tool })
.expect_reasoning("Let me read the diff content.")
.expect_tool_calls({{ "read_file_diff_md", "{}", {} }})
.expect_reconstruction()
.run();
}
}
@@ -2208,22 +2426,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Kimi-K2 old template
auto tst = peg_tester("models/templates/moonshotai-Kimi-K2.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test(
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>"
"{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>")
.tools({ special_function_tool })
.expect(kimi_id_special_func_tool_call)
.expect_reconstruction()
.run();
// Kimi-K2-Instruct
auto tst2 = peg_tester("models/templates/Kimi-K2-Instruct.jinja", detailed_debug);
tst2.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst2.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst2.test(
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>"
"{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>")
.tools({ special_function_tool })
.expect(kimi_id_special_func_tool_call)
.expect_reconstruction()
.run();
}
@@ -2297,6 +2517,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
// fake tool call marker in reasoning
tst.test(
"<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
"<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect_reasoning("Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm")
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
})
.run();
}
// Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format
@@ -2306,6 +2539,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test("<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
}
@@ -2314,7 +2548,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
tst.test(
"</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
.tools({ special_function_tool })
.expect(message_assist_call)
@@ -2364,37 +2598,41 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// mistralai-Mistral-Nemo-Instruct-2407.jinja
{
auto tst = peg_tester("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]")
.tools({ special_function_tool })
.expect(message_assist_call_id)
.expect_reconstruction()
.run();
}
{
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.1.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("<function=special_function>{\"arg1\": 1}</function>")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
}
// Functionary v3.2 - recipient-based format: >>>recipient\n{content}
{
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug);
tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).run();
tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("special_function\n{\"arg1\": 1}")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
}
// FireFunction
{
auto tst = peg_tester("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test(" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
}
@@ -2455,10 +2693,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
{ "models/templates/MiMo-VL.jinja", "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja",
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" }) {
auto tst = peg_tester(path, detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n</tool_call>")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
}
@@ -2481,6 +2720,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.expect(simple_assist_msg("Hello, world!\nWhat's up?", "Here are my reasoning steps:\nI'm\nthinking"))
.expect_reconstruction()
.run();
// Reasoning + Tool calls
@@ -2497,42 +2737,45 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// Mistral Small 3.2 - FUNC_BRACKET_TAG format: [TOOL_CALLS]func_name[CALL_ID]id[ARGS]{...}
{
auto tst = peg_tester("models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{\"arg1\": 1}")
.tools({ special_function_tool })
.expect(message_assist_call_id)
.expect_reconstruction()
.run();
}
// Devstral
{
auto tst = peg_tester("models/templates/unsloth-mistral-Devstral-Small-2507.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
tst.test("Hello, world!\nWhat's up?[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}")
.tools({ special_function_tool })
.expect(message_assist_call_content)
.expect_reconstruction()
.run();
}
{
// Llama 3.1
auto tst = peg_tester("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
}
{
// Llama 3.2
auto tst = peg_tester("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
}
{
// Llama 3.3
auto tst = peg_tester("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).expect_reconstruction().run();
}
// GPT-OSS format tests
@@ -2836,10 +3079,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
// GigaChat V3
{
auto tst = peg_tester("models/templates/GigaChat3-10B-A1.8B.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("<|message_sep|>\n\nfunction call<|role_sep|>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
tst.test(
@@ -2848,16 +3092,18 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
)
.tools({ special_function_tool })
.expect(message_assist_call_content)
.expect_reconstruction()
.run();
}
// GigaChat V3.1
{
auto tst = peg_tester("models/templates/GigaChat3.1-10B-A1.8B.jinja", detailed_debug);
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
tst.test("<|function_call|>{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}")
.tools({ special_function_tool })
.expect(message_assist_call)
.expect_reconstruction()
.run();
tst.test(
@@ -2866,6 +3112,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
)
.tools({ special_function_tool })
.expect(message_assist_call_content)
.expect_reconstruction()
.run();
}
}
@@ -3002,6 +3249,10 @@ int main(int argc, char ** argv) {
detailed_debug = true;
common_log_set_verbosity_thold(999);
}
if (arg == "--force-reconstruction-test") {
g_force_reconstruction_test = true;
only_run_filtered = true;
}
}
if (only_run_filtered) {
+13 -14
View File
@@ -61,8 +61,6 @@ static void test_reasoning_budget(
// Feed the sequence and track when forcing occurs
for (size_t i = 0; i < sequence.size(); i++) {
llama_sampler_accept(sampler, sequence[i]);
// Check if we're in forcing state by applying and seeing if logits are modified
cur_p.selected = -1;
for (size_t j = 0; j < cur.size(); j++) {
@@ -81,6 +79,8 @@ static void test_reasoning_budget(
}
}
llama_sampler_accept(sampler, sequence[i]);
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
if (finite_count == 1) {
@@ -167,9 +167,9 @@ int main(void) {
}
// Test 2: Budget exhausted, forcing should occur
// Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
// Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
// At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
// Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1
// i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1]
// At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
@@ -179,13 +179,12 @@ int main(void) {
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
2, // budget of 2 tokens
REASONING_BUDGET_IDLE,
2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
3); // forcing continues through i=3 (at i=4 state becomes DONE)
3, // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces)
4); // forcing continues through i=4 (accept at i=4 transitions to DONE)
}
// Test 3: Activate immediately with budget=0, forcing should start right away
// Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
// This test needs start token to be in the sequence or use activate_immediately with start token present
// Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
@@ -195,8 +194,8 @@ int main(void) {
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
0, // budget of 0 tokens
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
1); // forcing continues through i=1 (at i=2 state becomes DONE)
0, // forcing starts at i=0 (initialized in FORCING, apply forces immediately)
1); // forcing continues through i=1 (accept at i=1 transitions to DONE)
}
// Test 4: No start/end tokens configured - passthrough (no forcing)
@@ -214,7 +213,7 @@ int main(void) {
// Test 5: Activate immediately with budget > 0, count down then force
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
// So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
// Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned)
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
@@ -224,8 +223,8 @@ int main(void) {
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
2, // budget of 2 tokens
REASONING_BUDGET_COUNTING,
1, // forcing starts at i=1 (after 2 accepts deplete budget)
2); // forcing continues through i=2
2, // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces)
3); // forcing continues through i=3
}
printf("OK (5 tests passed)\n");
+81 -18
View File
@@ -100,7 +100,7 @@ struct cli_context {
}
// reasoning budget sampler
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
if (!chat_params.thinking_end_tag.empty()) {
const llama_vocab * vocab = llama_model_get_vocab(
llama_get_model(ctx_server.get_llama_context()));
@@ -224,10 +224,11 @@ struct cli_context {
};
// TODO?: Make this reusable, enums, docs
static const std::array<const std::string, 6> cmds = {
static const std::array<const std::string, 7> cmds = {
"/audio ",
"/clear",
"/exit",
"/glob ",
"/image ",
"/read ",
"/regen",
@@ -258,7 +259,7 @@ static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std:
}
}
if (!cmd.empty() && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
auto cur_dir = std::filesystem::current_path();
@@ -339,6 +340,8 @@ static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std:
return matches;
}
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
int main(int argc, char ** argv) {
common_params params;
@@ -430,7 +433,8 @@ int main(int argc, char ** argv) {
console::log(" /exit or Ctrl+C stop or exit\n");
console::log(" /regen regenerate the last response\n");
console::log(" /clear clear the chat history\n");
console::log(" /read add a text file\n");
console::log(" /read <file> add a text file\n");
console::log(" /glob <pattern> add text files using globbing pattern\n");
if (inf.has_inp_image) {
console::log(" /image <file> add an image file\n");
}
@@ -441,6 +445,27 @@ int main(int argc, char ** argv) {
// interactive loop
std::string cur_msg;
auto add_text_file = [&](const std::string & fname) -> bool {
std::string marker = ctx_cli.load_input_file(fname, false);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
return false;
}
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
cur_msg += fname;
cur_msg.push_back('\n');
} else {
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
}
cur_msg += marker;
console::log("Loaded text from '%s'\n", fname.c_str());
return true;
};
while (true) {
std::string buffer;
console::set_display(DISPLAY_TYPE_USER_INPUT);
@@ -525,22 +550,60 @@ int main(int argc, char ** argv) {
continue;
} else if (string_starts_with(buffer, "/read ")) {
std::string fname = string_strip(buffer.substr(6));
std::string marker = ctx_cli.load_input_file(fname, false);
if (marker.empty()) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
continue;
add_text_file(fname);
continue;
} else if (string_starts_with(buffer, "/glob ")) {
std::error_code ec;
size_t count = 0;
auto curdir = std::filesystem::current_path();
std::string pattern = string_strip(buffer.substr(6));
std::filesystem::path rel_path;
auto startglob = pattern.find_first_of("![*?");
if (startglob != std::string::npos && startglob != 0) {
auto endpath = pattern.substr(0, startglob).find_last_of('/');
if (endpath != std::string::npos) {
std::string rel_pattern = pattern.substr(0, endpath);
#if !defined(_WIN32)
if (string_starts_with(rel_pattern, "~")) {
const char * home = std::getenv("HOME");
if (home && home[0]) {
rel_pattern = std::string(home) + rel_pattern.substr(1);
}
}
#endif
rel_path = rel_pattern;
pattern.erase(0, endpath + 1);
curdir /= rel_path;
}
}
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
cur_msg += fname;
cur_msg.push_back('\n');
} else {
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
std::filesystem::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) {
continue;
}
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
if (ec) {
ec.clear();
continue;
}
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(pattern, rel)) {
continue;
}
if (!add_text_file((rel_path / rel).string())) {
continue;
}
if (++count >= FILE_GLOB_MAX_RESULTS) {
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
break;
}
}
cur_msg += marker;
console::log("Loaded text from '%s'\n", fname.c_str());
continue;
} else {
// not a command
+10
View File
@@ -1377,6 +1377,16 @@ struct clip_model_loader {
// sanity check
{
if (hparams.image_size < 0) {
// note: some models having hparams.image_size == 0, which means the image size is dynamic
throw std::runtime_error(string_format("%s: image_size (%d) cannot be negative\n", __func__, hparams.image_size));
}
if (hparams.patch_size <= 0) {
throw std::runtime_error(string_format("%s: patch_size (%d) must be greater than 0\n", __func__, hparams.patch_size));
}
if (hparams.n_embd <= 0) {
throw std::runtime_error(string_format("%s: n_embd (%d) must be greater than 0\n", __func__, hparams.n_embd));
}
if (hparams.image_max_pixels < hparams.image_min_pixels) {
throw std::runtime_error(string_format("%s: image_max_pixels (%d) is less than image_min_pixels (%d)\n", __func__, hparams.image_max_pixels, hparams.image_min_pixels));
}
+10 -8
View File
@@ -13,23 +13,20 @@
constexpr bool DEBUG = false;
void mtmd_audio_cache::fill_sin_cos_table(int n) {
void mtmd_audio_cache::fill_sin_cos_table(uint32_t n) {
sin_vals.resize(n);
cos_vals.resize(n);
for (int i = 0; i < n; i++) {
for (uint32_t i = 0; i < n; i++) {
double theta = (2 * M_PI * i) / n;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}
void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
void mtmd_audio_cache::fill_hann_window(uint32_t length, bool periodic) {
hann_window.resize(length);
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
int offset = periodic ? 0 : -1;
for (uint32_t i = 0; i < length; i++) {
hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
@@ -165,6 +162,7 @@ static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, fl
// false = input is complex-valued (interleaved real/imag, stride 2)
template <bool Inverse, bool RealInput>
static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
GGML_ASSERT(N > 0);
const int n_sin_cos_vals = cache.sin_vals.size();
if (N == 1) {
@@ -407,6 +405,8 @@ static bool log_mel_spectrogram(
}
GGML_ASSERT(params.n_fft_bins > 0);
GGML_ASSERT(params.hop_length > 0);
out.n_mel = params.n_mel;
out.n_len = (n_samples - frame_size) / frame_step + 1;
// TODO: handle these checks better
@@ -438,6 +438,7 @@ static bool log_mel_spectrogram(
const int effective_n_len = n_samples_in / frame_step;
if (params.norm_per_feature) {
GGML_ASSERT(effective_n_len > 1);
for (int i = 0; i < out.n_mel; i++) {
double mean = 0;
for (int j = 0; j < effective_n_len; ++j) {
@@ -639,6 +640,7 @@ mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length
padding_to_remove((n_fft - hop_length) / 2),
ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT
ifft_out(n_fft * 2 * 4, 0.0f) {
GGML_ASSERT(n_fft > 0 && hop_length > 0 && hop_length <= n_fft);
cache.fill_sin_cos_table(n_fft);
cache.fill_hann_window(n_fft, true);
}
+2 -2
View File
@@ -33,9 +33,9 @@ struct mtmd_audio_cache {
mtmd_audio_mel_filters filters;
void fill_sin_cos_table(int n);
void fill_sin_cos_table(uint32_t n);
void fill_hann_window(int length, bool periodic);
void fill_hann_window(uint32_t length, bool periodic);
// Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
// n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
+10
View File
@@ -127,6 +127,7 @@ struct decode_embd_batch {
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
GGML_ASSERT(n_tokens > 0 && n_pos_per_embd > 0 && n_mmproj_embd > 0);
pos .resize(n_tokens * n_pos_per_embd);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
@@ -157,6 +158,7 @@ struct decode_embd_batch {
// M-RoPE for image
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
GGML_ASSERT(n_pos_per_embd == 4);
GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
seq_id_0[0] = seq_id;
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
@@ -192,6 +194,7 @@ struct decode_embd_batch {
}
llama_batch get_view(int offset, int n_tokens) {
GGML_ASSERT(offset >= 0 && n_tokens > 0 && offset + n_tokens <= batch.n_tokens);
llama_pos * pos_ptr;
pos_view.clear();
pos_view.reserve(n_tokens * n_pos_per_embd);
@@ -235,6 +238,7 @@ int32_t mtmd_helper_decode_image_chunk(
llama_seq_id seq_id,
int32_t n_batch,
llama_pos * new_n_past) {
GGML_ASSERT(n_batch > 0);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
@@ -312,6 +316,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
int32_t n_batch,
bool logits_last,
llama_pos * new_n_past) {
GGML_ASSERT(n_batch > 0);
int32_t ret;
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
auto chunk_type = mtmd_input_chunk_get_type(chunk);
@@ -508,6 +513,11 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
fseek(f, 0, SEEK_END);
long file_size = ftell(f);
fseek(f, 0, SEEK_SET);
if (file_size < 0) {
LOG_ERR("Failed to get file size of %s\n", fname);
fclose(f);
return nullptr;
}
buf.resize(file_size);
size_t n_read = fread(buf.data(), 1, file_size, f);
+6 -2
View File
@@ -99,6 +99,8 @@ struct img_tool {
}
static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
GGML_ASSERT(x >= 0 && y >= 0 && w > 0 && h > 0);
GGML_ASSERT(x + w <= image.nx && y + h <= image.ny);
dst.nx = w;
dst.ny = h;
dst.buf.resize(3 * w * h);
@@ -196,6 +198,7 @@ struct img_tool {
private:
// Bilinear resize function
static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
GGML_ASSERT(src.nx >= 2 && src.ny >= 2);
dst.nx = target_width;
dst.ny = target_height;
dst.buf.resize(3 * target_width * target_height);
@@ -207,8 +210,8 @@ private:
for (int x = 0; x < target_width; x++) {
float px = x_ratio * x;
float py = y_ratio * y;
int x_floor = static_cast<int>(px);
int y_floor = static_cast<int>(py);
int x_floor = std::min(static_cast<int>(px), src.nx - 2);
int y_floor = std::min(static_cast<int>(py), src.ny - 2);
float x_lerp = px - x_floor;
float y_lerp = py - y_floor;
@@ -347,6 +350,7 @@ private:
// Returns: kernel size (ksize) - number of input pixels that contribute to each output pixel
auto precompute_weights = [&](int inSize, int outSize,
std::vector<int> & bounds, std::vector<int32_t> & weights) -> int {
GGML_ASSERT(inSize > 0 && outSize > 0);
double support, scale, filterscale;
double center, ww, ss;
int xx, x, ksize, xmin, xmax, xcnt;
+10 -1
View File
@@ -641,6 +641,11 @@ struct mtmd_tokenizer {
add_text(ctx->img_beg, true); // add image begin token
}
// sanity check
GGML_ASSERT(bitmap->nx > 0 && bitmap->ny > 0);
GGML_ASSERT(bitmap->data.size() == (size_t)bitmap->nx * bitmap->ny * 3);
GGML_ASSERT(ctx->image_preproc != nullptr);
// convert mtmd_bitmap to clip_image_u8
clip_image_u8_ptr img_u8(clip_image_u8_init());
img_u8->nx = bitmap->nx;
@@ -649,7 +654,6 @@ struct mtmd_tokenizer {
std::memcpy(img_u8->buf.data(), bitmap->data.data(), img_u8->nx * img_u8->ny * 3);
// preprocess image
GGML_ASSERT(ctx->image_preproc != nullptr);
clip_image_f32_batch batch_f32;
bool ok = ctx->image_preproc->preprocess(*img_u8, batch_f32);
if (!ok) {
@@ -773,6 +777,11 @@ struct mtmd_tokenizer {
add_text(ctx->aud_beg, true); // add audio begin token
}
// sanity check
GGML_ASSERT(ctx->audio_preproc != nullptr);
GGML_ASSERT(bitmap->data.size() > sizeof(float));
GGML_ASSERT(bitmap->data.size() % sizeof(float) == 0);
// preprocess audio
std::vector<mtmd_audio_mel> mel_spec_chunks;
const float * samples = (const float *)bitmap->data.data();
+23 -14
View File
@@ -13,6 +13,8 @@ add_library(${TARGET} STATIC
server-common.h
server-context.cpp
server-context.h
server-tools.cpp
server-tools.h
)
if (BUILD_SHARED_LIBS)
@@ -35,22 +37,29 @@ set(TARGET_SRCS
server-models.cpp
server-models.h
)
set(PUBLIC_ASSETS
index.html.gz
loading.html
)
foreach(asset ${PUBLIC_ASSETS})
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
list(APPEND TARGET_SRCS ${output})
add_custom_command(
DEPENDS "${input}"
OUTPUT "${output}"
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
option(LLAMA_BUILD_WEBUI "Build the embedded Web UI" ON)
if (LLAMA_BUILD_WEBUI)
set(PUBLIC_ASSETS
index.html.gz
loading.html
)
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
endforeach()
foreach(asset ${PUBLIC_ASSETS})
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
list(APPEND TARGET_SRCS ${output})
add_custom_command(
DEPENDS "${input}"
OUTPUT "${output}"
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
)
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
endforeach()
add_definitions(-DLLAMA_BUILD_WEBUI)
else()
endif()
add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME)
+55
View File
@@ -125,6 +125,61 @@ The framework automatically starts a `llama-server` instance, sends requests, an
For detailed instructions, see the [test documentation](./tests/README.md).
### API for tools
This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future.
**GET /tools**
Get a list of tools, each tool has these fields:
- `tool` (string): the ID name of the tool, to be used in POST call. Example: `read_file`
- `display_name` (string): the name to be displayed on UI. Example: `Read file`
- `type` (string): always be `"builtin"` for now
- `permissions` (object): a mapping string --> boolean that indicates the permission required by this tool. This is useful for the UI to ask the user before calling the tool. For now, the only permission supported is `"write"`
- `definition` (object): the OAI-compat definition of this tool
**POST /tools**
Invoke a tool call, request body is a JSON object with:
- `tool` (string): the name of the tool
- `params` (object): a mapping from argument name (string) to argument value
Returns JSON object. There are two response formats:
Format 1: Plain text. The text will be placed into a field called `plain_text_response`, example:
```json
{
"plain_text_response": "this is a text response"
}
```
The client should extract this value and place it inside message content (note: content is no longer a JSON), example
```json
{
"role": "tool",
"content": "this is a text response"
}
```
Format 2: Normal JSON response, example:
```json
{
"error": "cannot open this file"
}
```
That requires `JSON.stringify` when formatted to message content:
```json
{
"role": "tool",
"content": "{\"error\":\"cannot open this file\"}"
}
```
### Notable Related PRs
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
+26 -1
View File
@@ -36,7 +36,6 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--license` | show source code license and dependencies |
| `-cl, --cache-list` | show list of models in cache |
| `--completion-bash` | print source-able bash completion script for llama.cpp |
| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
| `-t, --threads N` | number of CPU threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -189,11 +188,13 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--tags STRING` | set model tags, comma-separated (informational, not used for routing)<br/>(env: LLAMA_ARG_TAGS) |
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
| `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)<br/>(env: LLAMA_ARG_REUSE_PORT) |
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
| `--webui-config JSON` | JSON that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG) |
| `--webui-config-file PATH` | JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) |
| `--webui-mcp-proxy, --no-webui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)<br/>(env: LLAMA_ARG_WEBUI_MCP_PROXY) |
| `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)<br/>specify "all" to enable all tools<br/>available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff<br/>(env: LLAMA_ARG_TOOLS) |
| `--webui, --no-webui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_WEBUI) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
@@ -293,6 +294,12 @@ It is currently available in the following endpoints:
For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
### Built-in tools support
The server includes a set of built-in tools that enable the LLM to access the local file system directly from the Web UI.
To use this feature, start the server with `--tools all`. You can also enable only specific tools by passing a comma-separated list: `--tools name1,name2,...`. Run `--help` for the full list of available tool names.
## Build
`llama-server` is built alongside everything else from the root of the project
@@ -1438,6 +1445,14 @@ curl http://localhost:8080/v1/messages/count_tokens \
{"input_tokens": 10}
```
## Server built-in tools
The server exposes a REST API under `/tools` that allows the Web UI to call built-in tools. This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future.
**Please do NOT use this endpoint in a downstream application**
For further documentation about this endpoint, please refer to [server internal documentation](./README-dev.md)
## Using multiple models
`llama-server` can be launched in a **router mode** that exposes an API for dynamically loading and unloading models. The main process (the "router") automatically forwards each request to the appropriate model instance.
@@ -1760,6 +1775,16 @@ Apart from error types supported by OAI, we also have custom types that are spec
}
```
### Custom default Web UI preferences
You can specify default preferences for the web UI using `--webui-config <JSON config>` or `--webui-config-file <path to JSON config>`. For example, you can disable pasting long text as attachments and enable rendering Markdown in user messages with this command:
```bash
./llama-server -m model.gguf --webui-config '{"pasteLongTextToFileLen": 0, "renderUserContentAsMarkdown": true}'
```
You may find available preferences in [settings-config.ts](webui/src/lib/constants/settings-config.ts).
### Legacy completion web UI
A new chat-based UI has replaced the old completion-based since [this PR](https://github.com/ggml-org/llama.cpp/pull/10175). If you want to use the old completion, start the server with `--path ./tools/server/public_legacy`
Binary file not shown.
+1 -1
View File
@@ -1110,7 +1110,7 @@ json oaicompat_chat_params_parse(
reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
}
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
if (!chat_params.thinking_end_tag.empty()) {
llama_params["reasoning_budget_tokens"] = reasoning_budget;
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
+1 -1
View File
@@ -2493,7 +2493,7 @@ private:
bool has_mtmd = false;
// check if we should process the image
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
size_t n_tokens_out = 0;
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
+18 -1
View File
@@ -8,9 +8,11 @@
#include <string>
#include <thread>
#ifdef LLAMA_BUILD_WEBUI
// auto generated files (see README.md for details)
#include "index.html.gz.hpp"
#include "loading.html.hpp"
#endif
//
// HTTP implementation using cpp-httplib
@@ -110,6 +112,16 @@ bool server_http_context::init(const common_params & params) {
// set timeouts and change hostname and port
srv->set_read_timeout (params.timeout_read);
srv->set_write_timeout(params.timeout_write);
srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) {
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
if (reuse_port) {
#ifdef SO_REUSEPORT
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEPORT, 1);
#else
LOG_WRN("%s: SO_REUSEPORT is not supported\n", __func__);
#endif
}
});
if (params.api_keys.size() == 1) {
auto key = params.api_keys[0];
@@ -181,11 +193,14 @@ bool server_http_context::init(const common_params & params) {
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
bool ready = is_ready.load();
if (!ready) {
#ifdef LLAMA_BUILD_WEBUI
auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") {
res.status = 503;
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
} else {
} else
#endif
{
// no endpoints is allowed to be accessed when the server is not ready
// this is to prevent any data races or inconsistent states
res.status = 503;
@@ -255,6 +270,7 @@ bool server_http_context::init(const common_params & params) {
return 1;
}
} else {
#ifdef LLAMA_BUILD_WEBUI
// using embedded static index.html
srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
@@ -268,6 +284,7 @@ bool server_http_context::init(const common_params & params) {
}
return false;
});
#endif
}
}
return true;
+10 -12
View File
@@ -478,19 +478,17 @@ task_params server_task::params_from_json_cmpl(
// Parse reasoning budget sampler parameters
{
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
if (budget >= 0) {
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
const auto message = json_value(data, "reasoning_budget_message", std::string());
params.sampling.reasoning_budget_tokens = budget;
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
const auto message = json_value(data, "reasoning_budget_message", std::string());
params.sampling.reasoning_budget_tokens = budget;
if (!start_tag.empty()) {
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
}
if (!end_tag.empty()) {
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
}
if (!start_tag.empty()) {
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
}
if (!end_tag.empty()) {
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
budget, params.sampling.generation_prompt.c_str(),
+768
View File
@@ -0,0 +1,768 @@
#include "server-tools.h"
#include <sheredom/subprocess.h>
#include <filesystem>
#include <fstream>
#include <regex>
#include <thread>
#include <chrono>
#include <atomic>
#include <cstring>
#include <climits>
namespace fs = std::filesystem;
//
// internal helpers
//
static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
std::vector<char *> r;
r.reserve(v.size() + 1);
for (const auto & s : v) {
r.push_back(const_cast<char *>(s.c_str()));
}
r.push_back(nullptr);
return r;
}
struct run_proc_result {
std::string output;
int exit_code = -1;
bool timed_out = false;
};
static run_proc_result run_process(
const std::vector<std::string> & args,
size_t max_output,
int timeout_secs) {
run_proc_result res;
subprocess_s proc;
auto argv = to_cstr_vec(args);
int options = subprocess_option_no_window
| subprocess_option_combined_stdout_stderr
| subprocess_option_inherit_environment
| subprocess_option_search_user_path;
if (subprocess_create(argv.data(), options, &proc) != 0) {
res.output = "failed to spawn process";
return res;
}
std::atomic<bool> done{false};
std::atomic<bool> timed_out{false};
std::thread timeout_thread([&]() {
auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_secs);
while (!done.load()) {
if (std::chrono::steady_clock::now() >= deadline) {
timed_out.store(true);
subprocess_terminate(&proc);
return;
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
});
FILE * f = subprocess_stdout(&proc);
std::string output;
bool truncated = false;
if (f) {
char buf[4096];
while (fgets(buf, sizeof(buf), f) != nullptr) {
if (!truncated) {
size_t len = strlen(buf);
if (output.size() + len <= max_output) {
output.append(buf, len);
} else {
output.append(buf, max_output - output.size());
truncated = true;
}
}
}
}
done.store(true);
if (timeout_thread.joinable()) {
timeout_thread.join();
}
subprocess_join(&proc, &res.exit_code);
subprocess_destroy(&proc);
res.output = output;
res.timed_out = timed_out.load();
if (truncated) {
res.output += "\n[output truncated]";
}
return res;
}
json server_tool::to_json() {
return {
{"display_name", display_name},
{"tool", name},
{"type", "builtin"},
{"permissions", json{
{"write", permission_write}
}},
{"definition", get_definition()},
};
}
//
// read_file: read a file with optional line range and line-number prefix
//
static constexpr size_t SERVER_TOOL_READ_FILE_MAX_SIZE = 16 * 1024; // 16 KB
struct server_tool_read_file : server_tool {
server_tool_read_file() {
name = "read_file";
display_name = "Read file";
permission_write = false;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Read the contents of a file. Optionally specify a 1-based line range. "
"If append_loc is true, each line is prefixed with its line number (e.g. \"1\u2192 ...\")."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Path to the file"}}},
{"start_line", {{"type", "integer"}, {"description", "First line to read, 1-based (default: 1)"}}},
{"end_line", {{"type", "integer"}, {"description", "Last line to read, 1-based inclusive (default: end of file)"}}},
{"append_loc", {{"type", "boolean"}, {"description", "Prefix each line with its line number"}}},
}},
{"required", json::array({"path"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
int start_line = json_value(params, "start_line", 1);
int end_line = json_value(params, "end_line", -1); // -1 = no limit
bool append_loc = json_value(params, "append_loc", false);
std::error_code ec;
uintmax_t file_size = fs::file_size(path, ec);
if (ec) {
return {{"error", "cannot stat file: " + ec.message()}};
}
if (file_size > SERVER_TOOL_READ_FILE_MAX_SIZE && end_line == -1) {
return {{"error", string_format(
"file too large (%zu bytes, max %zu). Use start_line/end_line to read a portion.",
(size_t)file_size, SERVER_TOOL_READ_FILE_MAX_SIZE)}};
}
std::ifstream f(path);
if (!f) {
return {{"error", "failed to open file: " + path}};
}
std::string result;
std::string line;
int lineno = 0;
while (std::getline(f, line)) {
lineno++;
if (lineno < start_line) continue;
if (end_line != -1 && lineno > end_line) break;
std::string out_line;
if (append_loc) {
out_line = std::to_string(lineno) + "\u2192 " + line + "\n";
} else {
out_line = line + "\n";
}
if (result.size() + out_line.size() > SERVER_TOOL_READ_FILE_MAX_SIZE) {
result += "[output truncated]";
break;
}
result += out_line;
}
return {{"plain_text_response", result}};
}
};
//
// file_glob_search: find files matching a glob pattern under a base directory
//
static constexpr size_t SERVER_TOOL_FILE_SEARCH_MAX_RESULTS = 100;
struct server_tool_file_glob_search : server_tool {
server_tool_file_glob_search() {
name = "file_glob_search";
display_name = "File search";
permission_write = false;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Recursively search for files matching a glob pattern under a directory."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Base directory to search in"}}},
{"include", {{"type", "string"}, {"description", "Glob pattern for files to include (e.g. \"**/*.cpp\"). Default: **"}}},
{"exclude", {{"type", "string"}, {"description", "Glob pattern for files to exclude"}}},
}},
{"required", json::array({"path"})},
}},
}},
};
}
json invoke(json params) override {
std::string base = params.at("path").get<std::string>();
std::string include = json_value(params, "include", std::string("**"));
std::string exclude = json_value(params, "exclude", std::string(""));
std::ostringstream output_text;
size_t count = 0;
std::error_code ec;
for (const auto & entry : fs::recursive_directory_iterator(base,
fs::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) continue;
std::string rel = fs::relative(entry.path(), base, ec).string();
if (ec) continue;
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(include, rel)) continue;
if (!exclude.empty() && glob_match(exclude, rel)) continue;
output_text << entry.path().string() << "\n";
if (++count >= SERVER_TOOL_FILE_SEARCH_MAX_RESULTS) {
break;
}
}
output_text << "\n---\nTotal matches: " << count << "\n";
return {{"plain_text_response", output_text.str()}};
}
};
//
// grep_search: search for a regex pattern in files
//
static constexpr size_t SERVER_TOOL_GREP_SEARCH_MAX_RESULTS = 100;
struct server_tool_grep_search : server_tool {
server_tool_grep_search() {
name = "grep_search";
display_name = "Grep search";
permission_write = false;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Search for a regex pattern in files under a path. Returns matching lines."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "File or directory to search in"}}},
{"pattern", {{"type", "string"}, {"description", "Regular expression pattern to search for"}}},
{"include", {{"type", "string"}, {"description", "Glob pattern to filter files (default: **)"}}},
{"exclude", {{"type", "string"}, {"description", "Glob pattern to exclude files"}}},
{"return_line_numbers", {{"type", "boolean"}, {"description", "If true, include line numbers in results"}}},
}},
{"required", json::array({"path", "pattern"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
std::string pat_str = params.at("pattern").get<std::string>();
std::string include = json_value(params, "include", std::string("**"));
std::string exclude = json_value(params, "exclude", std::string(""));
bool show_lineno = json_value(params, "return_line_numbers", false);
std::regex pattern;
try {
pattern = std::regex(pat_str);
} catch (const std::regex_error & e) {
return {{"error", std::string("invalid regex: ") + e.what()}};
}
std::ostringstream output_text;
size_t total = 0;
auto search_file = [&](const fs::path & fpath) {
std::ifstream f(fpath);
if (!f) return;
std::string line;
int lineno = 0;
while (std::getline(f, line) && total < SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) {
lineno++;
if (std::regex_search(line, pattern)) {
output_text << fpath.string() << ":";
if (show_lineno) {
output_text << lineno << ":";
}
output_text << line << "\n";
total++;
}
}
};
std::error_code ec;
if (fs::is_regular_file(path, ec)) {
search_file(path);
} else if (fs::is_directory(path, ec)) {
for (const auto & entry : fs::recursive_directory_iterator(path,
fs::directory_options::skip_permission_denied, ec)) {
if (!entry.is_regular_file()) continue;
if (total >= SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) break;
std::string rel = fs::relative(entry.path(), path, ec).string();
if (ec) continue;
std::replace(rel.begin(), rel.end(), '\\', '/');
if (!glob_match(include, rel)) continue;
if (!exclude.empty() && glob_match(exclude, rel)) continue;
search_file(entry.path());
}
} else {
return {{"error", "path does not exist: " + path}};
}
output_text << "\n\n---\nTotal matches: " << total << "\n";
return {{"plain_text_response", output_text.str()}};
}
};
//
// exec_shell_command: run an arbitrary shell command
//
static constexpr size_t SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE = 16 * 1024; // 16 KB
static constexpr int SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT = 60; // seconds
struct server_tool_exec_shell_command : server_tool {
server_tool_exec_shell_command() {
name = "exec_shell_command";
display_name = "Execute shell command";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Execute a shell command and return its output (stdout and stderr combined)."},
{"parameters", {
{"type", "object"},
{"properties", {
{"command", {{"type", "string"}, {"description", "Shell command to execute"}}},
{"timeout", {{"type", "integer"}, {"description", string_format("Timeout in seconds (default 10, max %d)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT)}}},
{"max_output_size", {{"type", "integer"}, {"description", string_format("Maximum output size in bytes (default %zu)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE)}}},
}},
{"required", json::array({"command"})},
}},
}},
};
}
json invoke(json params) override {
std::string command = params.at("command").get<std::string>();
int timeout = json_value(params, "timeout", 10);
size_t max_output = (size_t) json_value(params, "max_output_size", (int) SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
timeout = std::min(timeout, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT);
max_output = std::min(max_output, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
#ifdef _WIN32
std::vector<std::string> args = {"cmd", "/c", command};
#else
std::vector<std::string> args = {"sh", "-c", command};
#endif
auto res = run_process(args, max_output, timeout);
std::string text_output = res.output;
text_output += string_format("\n[exit code: %d]", res.exit_code);
if (res.timed_out) {
text_output += " [exit due to timed out]";
}
return {{"plain_text_response", text_output}};
}
};
//
// write_file: create or overwrite a file
//
struct server_tool_write_file : server_tool {
server_tool_write_file() {
name = "write_file";
display_name = "Write file";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Write content to a file, creating it (including parent directories) if it does not exist. May use with edit_file for more complex edits."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Path of the file to write"}}},
{"content", {{"type", "string"}, {"description", "Content to write"}}},
}},
{"required", json::array({"path", "content"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
std::string content = params.at("content").get<std::string>();
std::error_code ec;
fs::path fpath(path);
if (fpath.has_parent_path()) {
fs::create_directories(fpath.parent_path(), ec);
if (ec) {
return {{"error", "failed to create directories: " + ec.message()}};
}
}
std::ofstream f(path, std::ios::binary);
if (!f) {
return {{"error", "failed to open file for writing: " + path}};
}
f << content;
if (!f) {
return {{"error", "failed to write file: " + path}};
}
return {{"result", "file written successfully"}, {"path", path}, {"bytes", content.size()}};
}
};
//
// edit_file: edit file content via line-based changes
//
struct server_tool_edit_file : server_tool {
server_tool_edit_file() {
name = "edit_file";
display_name = "Edit file";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description",
"Edit a file by applying a list of line-based changes. "
"Each change targets a 1-based inclusive line range and has a mode: "
"\"replace\" (replace lines with content), "
"\"delete\" (remove lines, content must be empty string), "
"\"append\" (insert content after line_end). "
"Set line_start to -1 to target the end of file (line_end is ignored in that case). "
"Changes must not overlap. They are applied in reverse line order automatically."},
{"parameters", {
{"type", "object"},
{"properties", {
{"path", {{"type", "string"}, {"description", "Path to the file to edit"}}},
{"changes", {
{"type", "array"},
{"description", "List of changes to apply"},
{"items", {
{"type", "object"},
{"properties", {
{"mode", {{"type", "string"}, {"description", "\"replace\", \"delete\", or \"append\""}}},
{"line_start", {{"type", "integer"}, {"description", "First line of the range (1-based); use -1 for end of file"}}},
{"line_end", {{"type", "integer"}, {"description", "Last line of the range (1-based, inclusive); ignored when line_start is -1"}}},
{"content", {{"type", "string"}, {"description", "Content to insert; must be empty string for delete mode"}}},
}},
{"required", json::array({"mode", "line_start", "line_end", "content"})},
}},
}},
}},
{"required", json::array({"path", "changes"})},
}},
}},
};
}
json invoke(json params) override {
std::string path = params.at("path").get<std::string>();
const json & changes = params.at("changes");
if (!changes.is_array()) {
return {{"error", "\"changes\" must be an array"}};
}
// read file into lines
std::ifstream fin(path);
if (!fin) {
return {{"error", "failed to open file: " + path}};
}
std::vector<std::string> lines;
{
std::string line;
while (std::getline(fin, line)) {
lines.push_back(line);
}
}
fin.close();
// validate and collect changes, then sort descending by line_start
struct change_entry {
std::string mode;
int line_start; // 1-based
int line_end; // 1-based inclusive
std::string content;
};
std::vector<change_entry> entries;
entries.reserve(changes.size());
for (const auto & ch : changes) {
change_entry e;
e.mode = ch.at("mode").get<std::string>();
e.line_start = ch.at("line_start").get<int>();
e.line_end = ch.at("line_end").get<int>();
e.content = ch.at("content").get<std::string>();
if (e.mode != "replace" && e.mode != "delete" && e.mode != "append") {
return {{"error", "invalid mode \"" + e.mode + "\"; must be replace, delete, or append"}};
}
if (e.mode == "delete" && !e.content.empty()) {
return {{"error", "content must be empty string for delete mode"}};
}
int n = (int) lines.size();
if (e.line_start == -1) {
// -1 means end of file; line_end is ignored — normalize to point past last line
e.line_start = n + 1;
e.line_end = n + 1;
} else {
if (e.line_start < 1 || e.line_end < e.line_start) {
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
}
if (e.line_end > n) {
return {{"error", string_format("line_end %d exceeds file length %d", e.line_end, n)}};
}
}
entries.push_back(std::move(e));
}
// sort descending so earlier-indexed changes don't shift later ones
std::sort(entries.begin(), entries.end(), [](const change_entry & a, const change_entry & b) {
return a.line_start > b.line_start;
});
// apply changes (0-based indices internally)
for (const auto & e : entries) {
int idx_start = e.line_start - 1; // 0-based
int idx_end = e.line_end - 1; // 0-based inclusive
// split content into lines (preserve trailing newline awareness)
std::vector<std::string> new_lines;
if (!e.content.empty()) {
std::istringstream ss(e.content);
std::string ln;
while (std::getline(ss, ln)) {
new_lines.push_back(ln);
}
// if content ends with \n, getline consumed it — no extra empty line needed
// if content does NOT end with \n, last line is still captured correctly
}
if (e.mode == "replace") {
// erase [idx_start, idx_end] and insert new_lines
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
lines.insert(lines.begin() + idx_start, new_lines.begin(), new_lines.end());
} else if (e.mode == "delete") {
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
} else { // append
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
}
}
// write file back
std::ofstream fout(path, std::ios::binary);
if (!fout) {
return {{"error", "failed to open file for writing: " + path}};
}
for (size_t i = 0; i < lines.size(); i++) {
fout << lines[i];
if (i + 1 < lines.size()) {
fout << "\n";
}
}
if (!lines.empty()) {
fout << "\n";
}
if (!fout) {
return {{"error", "failed to write file: " + path}};
}
return {{"result", "file edited successfully"}, {"path", path}, {"lines", (int) lines.size()}};
}
};
//
// apply_diff: apply a unified diff via git apply
//
struct server_tool_apply_diff : server_tool {
server_tool_apply_diff() {
name = "apply_diff";
display_name = "Apply diff";
permission_write = true;
}
json get_definition() override {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", "Apply a unified diff to edit one or more files using git apply. Use this instead of edit_file when the changes are complex."},
{"parameters", {
{"type", "object"},
{"properties", {
{"diff", {{"type", "string"}, {"description", "Unified diff content in git diff format"}}},
}},
{"required", json::array({"diff"})},
}},
}},
};
}
json invoke(json params) override {
std::string diff = params.at("diff").get<std::string>();
// write diff to a temporary file
static std::atomic<int> counter{0};
std::string tmp_path = (fs::temp_directory_path() /
("llama_patch_" + std::to_string(++counter) + ".patch")).string();
{
std::ofstream f(tmp_path, std::ios::binary);
if (!f) {
return {{"error", "failed to create temp patch file"}};
}
f << diff;
}
auto res = run_process({"git", "apply", tmp_path}, 4096, 10);
std::error_code ec;
fs::remove(tmp_path, ec);
if (res.exit_code != 0) {
return {{"error", "git apply failed (exit " + std::to_string(res.exit_code) + "): " + res.output}};
}
return {{"result", "patch applied successfully"}};
}
};
//
// public API
//
static std::vector<std::unique_ptr<server_tool>> build_tools() {
std::vector<std::unique_ptr<server_tool>> tools;
tools.push_back(std::make_unique<server_tool_read_file>());
tools.push_back(std::make_unique<server_tool_file_glob_search>());
tools.push_back(std::make_unique<server_tool_grep_search>());
tools.push_back(std::make_unique<server_tool_exec_shell_command>());
tools.push_back(std::make_unique<server_tool_write_file>());
tools.push_back(std::make_unique<server_tool_edit_file>());
tools.push_back(std::make_unique<server_tool_apply_diff>());
return tools;
}
void server_tools::setup(const std::vector<std::string> & enabled_tools) {
if (!enabled_tools.empty()) {
std::unordered_set<std::string> enabled_set(enabled_tools.begin(), enabled_tools.end());
auto all_tools = build_tools();
tools.clear();
for (auto & t : all_tools) {
if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) {
tools.push_back(std::move(t));
}
}
}
handle_get = [this](const server_http_req &) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json result = json::array();
for (const auto & t : tools) {
result.push_back(t->to_json());
}
res->data = safe_json_to_str(result);
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};
handle_post = [this](const server_http_req & req) -> server_http_res_ptr {
auto res = std::make_unique<server_http_res>();
try {
json body = json::parse(req.body);
std::string tool_name = body.at("tool").get<std::string>();
json params = body.value("params", json::object());
json result = invoke(tool_name, params);
res->data = safe_json_to_str(result);
} catch (const json::exception & e) {
res->status = 400;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
} catch (const std::exception & e) {
SRV_ERR("got exception: %s\n", e.what());
res->status = 500;
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
}
return res;
};
}
json server_tools::invoke(const std::string & name, const json & params) {
for (auto & t : tools) {
if (t->name == name) {
return t->invoke(params);
}
}
return {{"error", "unknown tool: " + name}};
}
+26
View File
@@ -0,0 +1,26 @@
#pragma once
#include "server-common.h"
#include "server-http.h"
struct server_tool {
std::string name;
std::string display_name;
bool permission_write = false;
virtual ~server_tool() = default;
virtual json get_definition() = 0;
virtual json invoke(json params) = 0;
json to_json();
};
struct server_tools {
std::vector<std::unique_ptr<server_tool>> tools;
void setup(const std::vector<std::string> & enabled_tools);
json invoke(const std::string & name, const json & params);
server_http_context::handler_t handle_get;
server_http_context::handler_t handle_post;
};
+12
View File
@@ -2,6 +2,7 @@
#include "server-http.h"
#include "server-models.h"
#include "server-cors-proxy.h"
#include "server-tools.h"
#include "arg.h"
#include "common.h"
@@ -124,6 +125,7 @@ int main(int argc, char ** argv) {
// register API routes
server_routes routes(params, ctx_server);
server_tools tools;
bool is_router_server = params.model.path.empty();
std::optional<server_models_routes> models_routes{};
@@ -211,6 +213,16 @@ int main(int argc, char ** argv) {
ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get));
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
}
// EXPERIMENTAL built-in tools
if (!params.server_tools.empty()) {
tools.setup(params.server_tools);
SRV_WRN("%s", "-----------------\n");
SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n");
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n");
SRV_WRN("%s", "-----------------\n");
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
}
//
// Start the server
+9 -1
View File
@@ -288,7 +288,15 @@ class ServerProcess:
server_instances.remove(self)
if self.process:
print(f"Stopping server with pid={self.process.pid}")
self.process.kill()
self.process.terminate()
try:
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(f"Server pid={self.process.pid} did not terminate in time, killing")
self.process.kill()
self.process.wait(timeout=5)
except Exception as e:
print(f"Error waiting for server: {e}")
self.process = None
def make_request(
@@ -0,0 +1,54 @@
/**
* Svelte action that fades in an element when it enters the viewport.
* Uses IntersectionObserver for efficient viewport detection.
*
* If skipIfVisible is set and the element is already visible in the viewport
* when the action attaches (e.g. a markdown block promoted from unstable
* during streaming), the fade is skipped entirely to avoid a flash.
*/
export function fadeInView(
node: HTMLElement,
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
) {
const { duration = 300, y = 0, skipIfVisible = false } = options;
if (skipIfVisible) {
const rect = node.getBoundingClientRect();
const isAlreadyVisible =
rect.top < window.innerHeight &&
rect.bottom > 0 &&
rect.left < window.innerWidth &&
rect.right > 0;
if (isAlreadyVisible) {
return;
}
}
node.style.opacity = '0';
node.style.transform = `translateY(${y}px)`;
node.style.transition = `opacity ${duration}ms ease-out, transform ${duration}ms ease-out`;
$effect(() => {
const observer = new IntersectionObserver(
(entries) => {
for (const entry of entries) {
if (entry.isIntersecting) {
requestAnimationFrame(() => {
node.style.opacity = '1';
node.style.transform = 'translateY(0)';
});
observer.disconnect();
}
}
},
{ threshold: 0.05 }
);
observer.observe(node);
return () => {
observer.disconnect();
};
});
}
@@ -10,9 +10,9 @@
ModelsSelector,
ModelsSelectorSheet
} from '$lib/components/app';
import { DialogChatSettings } from '$lib/components/app/dialogs';
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { getChatSettingsDialogContext } from '$lib/contexts';
import { FileTypeCategory } from '$lib/enums';
import { getFileTypeCategory } from '$lib/utils';
import { config } from '$lib/stores/settings.svelte';
@@ -169,7 +169,7 @@
selectorModelRef?.open();
}
let showChatSettingsDialogWithMcpSection = $state(false);
const chatSettingsDialog = getChatSettingsDialogContext();
let hasMcpPromptsSupport = $derived.by(() => {
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
@@ -197,7 +197,7 @@
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
onMcpSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
{:else}
<ChatFormActionAttachmentsDropdown
@@ -210,13 +210,13 @@
{onSystemPromptClick}
{onMcpPromptClick}
{onMcpResourcesClick}
onMcpSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
{/if}
<McpServersSelector
{disabled}
onSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
onSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
/>
</div>
@@ -265,9 +265,3 @@
/>
{/if}
</div>
<DialogChatSettings
open={showChatSettingsDialogWithMcpSection}
onOpenChange={(open) => (showChatSettingsDialogWithMcpSection = open)}
initialSection={SETTINGS_SECTION_TITLES.MCP}
/>
@@ -180,6 +180,10 @@
chatActions.continueAssistantMessage(message);
}
function handleForkConversation(options: { name: string; includeAttachments: boolean }) {
chatActions.forkConversation(message, options);
}
function handleNavigateToSibling(siblingId: string) {
chatActions.navigateToSibling(siblingId);
}
@@ -285,6 +289,7 @@
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
{showDeleteDialog}
@@ -303,6 +308,7 @@
onCopy={handleCopy}
onDelete={handleDelete}
onEdit={handleEdit}
onForkConversation={handleForkConversation}
onNavigateToSibling={handleNavigateToSibling}
onRegenerate={handleRegenerate}
onShowDeleteDialogChange={handleShowDeleteDialogChange}
@@ -1,12 +1,16 @@
<script lang="ts">
import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
import { Edit, Copy, RefreshCw, Trash2, ArrowRight, GitBranch } from '@lucide/svelte';
import {
ActionIcon,
ChatMessageBranchingControls,
DialogConfirmation
} from '$lib/components/app';
import { Switch } from '$lib/components/ui/switch';
import { Checkbox } from '$lib/components/ui/checkbox';
import Input from '$lib/components/ui/input/input.svelte';
import Label from '$lib/components/ui/label/label.svelte';
import { MessageRole } from '$lib/enums';
import { activeConversation } from '$lib/stores/conversations.svelte';
interface Props {
role: MessageRole.USER | MessageRole.ASSISTANT;
@@ -24,6 +28,7 @@
onEdit?: () => void;
onRegenerate?: () => void;
onContinue?: () => void;
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
onDelete: () => void;
onConfirmDelete: () => void;
onNavigateToSibling?: (siblingId: string) => void;
@@ -42,6 +47,7 @@
onConfirmDelete,
onContinue,
onDelete,
onForkConversation,
onNavigateToSibling,
onShowDeleteDialogChange,
onRegenerate,
@@ -53,10 +59,27 @@
onRawOutputToggle
}: Props = $props();
let showForkDialog = $state(false);
let forkName = $state('');
let forkIncludeAttachments = $state(true);
function handleConfirmDelete() {
onConfirmDelete();
onShowDeleteDialogChange(false);
}
function handleOpenForkDialog() {
const conv = activeConversation();
forkName = `Fork of ${conv?.name ?? 'Conversation'}`;
forkIncludeAttachments = true;
showForkDialog = true;
}
function handleConfirmFork() {
onForkConversation?.({ name: forkName.trim(), includeAttachments: forkIncludeAttachments });
showForkDialog = false;
}
</script>
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-between">
@@ -86,6 +109,10 @@
<ActionIcon icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
{/if}
{#if onForkConversation}
<ActionIcon icon={GitBranch} tooltip="Fork conversation" onclick={handleOpenForkDialog} />
{/if}
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
</div>
</div>
@@ -116,3 +143,42 @@
onConfirm={handleConfirmDelete}
onCancel={() => onShowDeleteDialogChange(false)}
/>
<DialogConfirmation
bind:open={showForkDialog}
title="Fork Conversation"
description="Create a new conversation branching from this message."
confirmText="Fork"
cancelText="Cancel"
icon={GitBranch}
onConfirm={handleConfirmFork}
onCancel={() => (showForkDialog = false)}
>
<div class="flex flex-col gap-4 py-2">
<div class="flex flex-col gap-2">
<Label for="fork-name">Title</Label>
<Input
id="fork-name"
class="text-foreground"
placeholder="Enter fork name"
type="text"
bind:value={forkName}
/>
</div>
<div class="flex items-center gap-2">
<Checkbox
id="fork-attachments"
checked={forkIncludeAttachments}
onCheckedChange={(checked) => {
forkIncludeAttachments = checked === true;
}}
/>
<Label for="fork-attachments" class="cursor-pointer text-sm font-normal">
Include all attachments
</Label>
</div>
</div>
</DialogConfirmation>
@@ -3,14 +3,12 @@
ChatMessageAgenticContent,
ChatMessageActions,
ChatMessageStatistics,
MarkdownContent,
ModelBadge,
ModelsSelector
} from '$lib/components/app';
import { getMessageEditContext } from '$lib/contexts';
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
import { agenticStreamingToolCall } from '$lib/stores/agentic.svelte';
import { autoResizeTextarea, copyToClipboard, isIMEComposing } from '$lib/utils';
import { tick } from 'svelte';
import { fade } from 'svelte/transition';
@@ -41,6 +39,7 @@
onContinue?: () => void;
onDelete: () => void;
onEdit?: () => void;
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
onNavigateToSibling?: (siblingId: string) => void;
onRegenerate: (modelOverride?: string) => void;
onShowDeleteDialogChange: (show: boolean) => void;
@@ -60,6 +59,7 @@
onCopy,
onDelete,
onEdit,
onForkConversation,
onNavigateToSibling,
onRegenerate,
onShowDeleteDialogChange,
@@ -87,13 +87,7 @@
const hasAgenticMarkers = $derived(
messageContent?.includes(AGENTIC_TAGS.TOOL_CALL_START) ?? false
);
const hasStreamingToolCall = $derived(
isChatStreaming() && agenticStreamingToolCall(message.convId) !== null
);
const hasReasoningMarkers = $derived(messageContent?.includes(REASONING_TAGS.START) ?? false);
const isStructuredContent = $derived(
hasAgenticMarkers || hasReasoningMarkers || hasStreamingToolCall
);
const processingState = useProcessingState();
let currentConfig = $derived(config());
@@ -256,15 +250,13 @@
{:else if message.role === MessageRole.ASSISTANT}
{#if showRawOutput}
<pre class="raw-output">{messageContent || ''}</pre>
{:else if isStructuredContent}
{:else}
<ChatMessageAgenticContent
content={messageContent || ''}
isStreaming={isChatStreaming()}
highlightTurns={highlightAgenticTurns}
{message}
/>
{:else}
<MarkdownContent content={messageContent || ''} attachments={message.extra} />
{/if}
{:else}
<div class="text-sm whitespace-pre-wrap">
@@ -355,6 +347,7 @@
onContinue={currentConfig.enableContinueGeneration && !hasReasoningMarkers
? onContinue
: undefined}
{onForkConversation}
{onDelete}
{onConfirmDelete}
{onNavigateToSibling}
@@ -21,6 +21,7 @@
onEdit: () => void;
onDelete: () => void;
onConfirmDelete: () => void;
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
onShowDeleteDialogChange: (show: boolean) => void;
onNavigateToSibling?: (siblingId: string) => void;
onCopy: () => void;
@@ -35,6 +36,7 @@
onEdit,
onDelete,
onConfirmDelete,
onForkConversation,
onShowDeleteDialogChange,
onNavigateToSibling,
onCopy
@@ -114,6 +116,7 @@
{onCopy}
{onDelete}
{onEdit}
{onForkConversation}
{onNavigateToSibling}
{onShowDeleteDialogChange}
{siblingInfo}
@@ -1,4 +1,5 @@
<script lang="ts">
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { ChatMessage } from '$lib/components/app';
import { setChatActionsContext } from '$lib/contexts';
import { MessageRole } from '$lib/enums';
@@ -78,6 +79,13 @@
onUserAction?.();
await chatStore.continueAssistantMessage(message.id);
refreshAllMessages();
},
forkConversation: async (
message: DatabaseMessage,
options: { name: string; includeAttachments: boolean }
) => {
await conversationsStore.forkConversation(message.id, options);
}
});
@@ -140,13 +148,18 @@
});
</script>
<div class="flex h-full flex-col space-y-10 pt-24 {className}" style="height: auto; ">
<div
class="flex h-full flex-col space-y-10 pt-24 {className}"
style="height: auto; min-height: calc(100dvh - 14rem);"
>
{#each displayMessages as { message, isLastAssistantMessage, siblingInfo } (message.id)}
<ChatMessage
class="mx-auto w-full max-w-[48rem]"
{message}
{isLastAssistantMessage}
{siblingInfo}
/>
<div use:fadeInView>
<ChatMessage
class="mx-auto w-full max-w-[48rem]"
{message}
{isLastAssistantMessage}
{siblingInfo}
/>
</div>
{/each}
</div>
@@ -12,7 +12,6 @@
} from '$lib/components/app';
import * as Alert from '$lib/components/ui/alert';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import { INITIAL_SCROLL_DELAY } from '$lib/constants';
import { KeyboardKey } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import {
@@ -48,7 +47,7 @@
let showFileErrorDialog = $state(false);
let uploadedFiles = $state<ChatUploadedFile[]>([]);
const autoScroll = createAutoScrollController();
const autoScroll = createAutoScrollController({ isColumnReverse: true });
let fileErrorData = $state<{
generallyUnsupported: File[];
@@ -310,13 +309,15 @@
afterNavigate(() => {
if (!disableAutoScroll) {
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
autoScroll.enable();
}
});
onMount(() => {
autoScroll.startObserving();
if (!disableAutoScroll) {
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
autoScroll.enable();
}
const pendingDraft = chatStore.consumePendingDraft();
@@ -333,10 +334,6 @@
$effect(() => {
autoScroll.setDisabled(disableAutoScroll);
});
$effect(() => {
autoScroll.updateInterval(isCurrentConversationLoading);
});
</script>
{#if isDragOver}
@@ -351,7 +348,7 @@
<div
bind:this={chatScrollContainer}
aria-label="Chat interface with file drop zone"
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
class="flex h-full flex-col-reverse overflow-y-auto px-4 md:px-6"
ondragenter={handleDragEnter}
ondragleave={handleDragLeave}
ondragover={handleDragOver}
@@ -359,57 +356,59 @@
onscroll={handleScroll}
role="main"
>
<ChatMessages
class="mb-16 md:mb-24"
messages={activeMessages()}
onUserAction={() => {
autoScroll.enable();
autoScroll.scrollToBottom();
}}
/>
<div class="flex flex-col">
<ChatMessages
class="mb-16 md:mb-24"
messages={activeMessages()}
onUserAction={() => {
autoScroll.enable();
autoScroll.scrollToBottom();
}}
/>
<div
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
in:slide={{ duration: 150, axis: 'y' }}
>
<ChatScreenProcessingInfo />
<div
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
in:slide={{ duration: 150, axis: 'y' }}
>
<ChatScreenProcessingInfo />
{#if hasPropsError}
<div
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
in:fly={{ y: 10, duration: 250 }}
>
<Alert.Root variant="destructive">
<AlertTriangle class="h-4 w-4" />
<Alert.Title class="flex items-center justify-between">
<span>Server unavailable</span>
<button
onclick={() => serverStore.fetch()}
disabled={isServerLoading}
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
>
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
{isServerLoading ? 'Retrying...' : 'Retry'}
</button>
</Alert.Title>
<Alert.Description>{serverError()}</Alert.Description>
</Alert.Root>
{#if hasPropsError}
<div
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
in:fly={{ y: 10, duration: 250 }}
>
<Alert.Root variant="destructive">
<AlertTriangle class="h-4 w-4" />
<Alert.Title class="flex items-center justify-between">
<span>Server unavailable</span>
<button
onclick={() => serverStore.fetch()}
disabled={isServerLoading}
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
>
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
{isServerLoading ? 'Retrying...' : 'Retry'}
</button>
</Alert.Title>
<Alert.Description>{serverError()}</Alert.Description>
</Alert.Root>
</div>
{/if}
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
<ChatScreenForm
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
showHelperText={false}
bind:uploadedFiles
/>
</div>
{/if}
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
<ChatScreenForm
disabled={hasPropsError || isEditing()}
{initialMessage}
isLoading={isCurrentConversationLoading}
onFileRemove={handleFileRemove}
onFileUpload={handleFileUpload}
onSend={handleSendMessage}
onStop={() => chatStore.stopGeneration()}
onSystemPromptAdd={handleSystemPromptAdd}
showHelperText={false}
bind:uploadedFiles
/>
</div>
</div>
</div>
@@ -1,16 +1,11 @@
<script lang="ts">
import { Settings } from '@lucide/svelte';
import { DialogChatSettings } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { useSidebar } from '$lib/components/ui/sidebar';
import { getChatSettingsDialogContext } from '$lib/contexts';
const sidebar = useSidebar();
let settingsOpen = $state(false);
function toggleSettings() {
settingsOpen = true;
}
const chatSettingsDialog = getChatSettingsDialogContext();
</script>
<header
@@ -22,12 +17,10 @@
<Button
variant="ghost"
size="icon-lg"
onclick={toggleSettings}
onclick={() => chatSettingsDialog.open()}
class="rounded-full backdrop-blur-lg"
>
<Settings class="h-4 w-4" />
</Button>
</div>
</header>
<DialogChatSettings open={settingsOpen} onOpenChange={(open) => (settingsOpen = open)} />
@@ -1,13 +1,18 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { page } from '$app/state';
import { Trash2 } from '@lucide/svelte';
import { Trash2, Pencil } from '@lucide/svelte';
import { ChatSidebarConversationItem, DialogConfirmation } from '$lib/components/app';
import { Checkbox } from '$lib/components/ui/checkbox';
import Label from '$lib/components/ui/label/label.svelte';
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
import * as Sidebar from '$lib/components/ui/sidebar';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import Input from '$lib/components/ui/input/input.svelte';
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
import {
conversationsStore,
conversations,
buildConversationTree
} from '$lib/stores/conversations.svelte';
import { chatStore } from '$lib/stores/chat.svelte';
import { getPreviewText } from '$lib/utils';
import ChatSidebarActions from './ChatSidebarActions.svelte';
@@ -18,6 +23,7 @@
let isSearchModeActive = $state(false);
let searchQuery = $state('');
let showDeleteDialog = $state(false);
let deleteWithForks = $state(false);
let showEditDialog = $state(false);
let selectedConversation = $state<DatabaseConversation | null>(null);
let editedName = $state('');
@@ -35,10 +41,30 @@
return conversations();
});
let conversationTree = $derived(buildConversationTree(filteredConversations));
let selectedConversationHasDescendants = $derived.by(() => {
if (!selectedConversation) return false;
const allConvs = conversations();
const queue = [selectedConversation.id];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of allConvs) {
if (c.forkedFromConversationId === parentId) return true;
}
}
return false;
});
async function handleDeleteConversation(id: string) {
const conversation = conversations().find((conv) => conv.id === id);
if (conversation) {
selectedConversation = conversation;
deleteWithForks = false;
showDeleteDialog = true;
}
}
@@ -54,11 +80,14 @@
function handleConfirmDelete() {
if (selectedConversation) {
const convId = selectedConversation.id;
const withForks = deleteWithForks;
showDeleteDialog = false;
setTimeout(() => {
conversationsStore.deleteConversation(selectedConversation.id);
selectedConversation = null;
conversationsStore.deleteConversation(convId, {
deleteWithForks: withForks
});
}, 100); // Wait for animation to finish
}
}
@@ -110,7 +139,7 @@
</script>
<ScrollArea class="h-[100vh]">
<Sidebar.Header class=" top-0 z-10 gap-6 bg-sidebar/50 px-4 py-4 pb-2 backdrop-blur-lg md:sticky">
<Sidebar.Header class=" top-0 z-10 gap-4 bg-sidebar/50 p-4 pb-2 backdrop-blur-lg md:sticky">
<a href="#/" onclick={handleMobileSidebarItemClick}>
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">llama.cpp</h1>
</a>
@@ -118,7 +147,7 @@
<ChatSidebarActions {handleMobileSidebarItemClick} bind:isSearchModeActive bind:searchQuery />
</Sidebar.Header>
<Sidebar.Group class="mt-4 space-y-2 p-0 px-4">
<Sidebar.Group class="mt-2 space-y-2 p-0 px-4">
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
<Sidebar.GroupLabel>
{isSearchModeActive ? 'Search results' : 'Conversations'}
@@ -127,15 +156,17 @@
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each filteredConversations as conversation (conversation.id)}
<Sidebar.MenuItem class="mb-1">
{#each conversationTree as { conversation, depth } (conversation.id)}
<Sidebar.MenuItem class="mb-1 p-0">
<ChatSidebarConversationItem
conversation={{
id: conversation.id,
name: conversation.name,
lastModified: conversation.lastModified,
currNode: conversation.currNode
currNode: conversation.currNode,
forkedFromConversationId: conversation.forkedFromConversationId
}}
{depth}
{handleMobileSidebarItemClick}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
@@ -146,7 +177,7 @@
</Sidebar.MenuItem>
{/each}
{#if filteredConversations.length === 0}
{#if conversationTree.length === 0}
<div class="px-2 py-4 text-center">
<p class="mb-4 p-4 text-sm text-muted-foreground">
{searchQuery.length > 0
@@ -177,35 +208,40 @@
showDeleteDialog = false;
selectedConversation = null;
}}
/>
>
{#if selectedConversationHasDescendants}
<div class="flex items-center gap-2 py-2">
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
<AlertDialog.Root bind:open={showEditDialog}>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title>Edit Conversation Name</AlertDialog.Title>
<AlertDialog.Description>
<Input
class="mt-4 text-foreground"
onkeydown={(e) => {
if (e.key === 'Enter') {
e.preventDefault();
handleConfirmEdit();
}
}}
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</AlertDialog.Description>
</AlertDialog.Header>
<AlertDialog.Footer>
<AlertDialog.Cancel
onclick={() => {
showEditDialog = false;
selectedConversation = null;
}}>Cancel</AlertDialog.Cancel
>
<AlertDialog.Action onclick={handleConfirmEdit}>Save</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
</div>
{/if}
</DialogConfirmation>
<DialogConfirmation
bind:open={showEditDialog}
title="Edit Conversation Name"
description=""
confirmText="Save"
cancelText="Cancel"
icon={Pencil}
onConfirm={handleConfirmEdit}
onCancel={() => {
showEditDialog = false;
selectedConversation = null;
}}
onKeydown={(e) => {
if (e.key === 'Enter') {
e.preventDefault();
e.stopImmediatePropagation();
handleConfirmEdit();
}
}}
>
<Input
class="text-foreground"
placeholder="Enter a new name"
type="text"
bind:value={editedName}
/>
</DialogConfirmation>
@@ -3,6 +3,9 @@
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { McpLogo } from '$lib/components/app';
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
import { getChatSettingsDialogContext } from '$lib/contexts';
interface Props {
handleMobileSidebarItemClick: () => void;
@@ -18,6 +21,8 @@
let searchInput: HTMLInputElement | null = $state(null);
const chatSettingsDialog = getChatSettingsDialogContext();
function handleSearchModeDeactivate() {
isSearchModeActive = false;
searchQuery = '';
@@ -30,7 +35,7 @@
});
</script>
<div class="space-y-0.5">
<div class="my-1 space-y-1">
{#if isSearchModeActive}
<div class="relative">
<Search class="absolute top-2.5 left-2 h-4 w-4 text-muted-foreground" />
@@ -50,13 +55,14 @@
</div>
{:else}
<Button
class="w-full justify-between hover:[&>kbd]:opacity-100"
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
href="?new_chat=true#/"
onclick={handleMobileSidebarItemClick}
variant="ghost"
>
<div class="flex items-center gap-2">
<SquarePen class="h-4 w-4" />
New chat
</div>
@@ -64,7 +70,7 @@
</Button>
<Button
class="w-full justify-between hover:[&>kbd]:opacity-100"
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={() => {
isSearchModeActive = true;
}}
@@ -72,10 +78,25 @@
>
<div class="flex items-center gap-2">
<Search class="h-4 w-4" />
Search conversations
Search
</div>
<KeyboardShortcutInfo keys={['cmd', 'k']} />
</Button>
<Button
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
onclick={() => {
chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP);
}}
variant="ghost"
>
<div class="flex items-center gap-2">
<McpLogo class="h-4 w-4" />
MCP Servers
</div>
</Button>
{/if}
</div>
@@ -1,13 +1,23 @@
<script lang="ts">
import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
import {
Trash2,
Pencil,
MoreHorizontal,
Download,
Loader2,
Square,
GitBranch
} from '@lucide/svelte';
import { DropdownMenuActions } from '$lib/components/app';
import * as Tooltip from '$lib/components/ui/tooltip';
import { FORK_TREE_DEPTH_PADDING } from '$lib/constants';
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import { onMount } from 'svelte';
interface Props {
isActive?: boolean;
depth?: number;
conversation: DatabaseConversation;
handleMobileSidebarItemClick?: () => void;
onDelete?: (id: string) => void;
@@ -23,7 +33,8 @@
onEdit,
onSelect,
onStop,
isActive = false
isActive = false,
depth = 0
}: Props = $props();
let renderActionsDropdown = $state(false);
@@ -88,14 +99,34 @@
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
<button
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg px-3 py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
? 'bg-foreground/5 text-accent-foreground'
: ''}"
: ''} px-3"
onclick={handleSelect}
onmouseover={handleMouseOver}
onmouseleave={handleMouseLeave}
>
<div class="flex min-w-0 flex-1 items-center gap-2">
<div
class="flex min-w-0 flex-1 items-center gap-2"
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
>
{#if depth > 0}
<Tooltip.Root>
<Tooltip.Trigger>
<a
href="#/chat/{conversation.forkedFromConversationId}"
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
>
<GitBranch class="h-3.5 w-3.5" />
</a>
</Tooltip.Trigger>
<Tooltip.Content>
<p>See parent conversation</p>
</Tooltip.Content>
</Tooltip.Root>
{/if}
{#if isLoading}
<Tooltip.Root>
<Tooltip.Trigger>
@@ -36,6 +36,7 @@
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import type { DatabaseMessageExtra } from '$lib/types/database';
import { config } from '$lib/stores/settings.svelte';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
interface Props {
attachments?: DatabaseMessageExtra[];
@@ -598,7 +599,7 @@
: ''}"
>
{#each renderedBlocks as block (block.id)}
<div class="markdown-block" data-block-id={block.id}>
<div class="markdown-block" data-block-id={block.id} use:fadeInView={{ skipIfVisible: true }}>
<!-- eslint-disable-next-line no-at-html-tags -->
{@html block.html}
</div>
@@ -651,7 +652,6 @@
/>
<style>
.markdown-block,
.markdown-block--unstable {
display: contents;
}
@@ -1,6 +1,6 @@
<script lang="ts">
import * as AlertDialog from '$lib/components/ui/alert-dialog';
import type { Component } from 'svelte';
import type { Component, Snippet } from 'svelte';
import { KeyboardKey } from '$lib/enums';
interface Props {
@@ -14,6 +14,7 @@
onConfirm: () => void;
onCancel: () => void;
onKeydown?: (event: KeyboardEvent) => void;
children?: Snippet;
}
let {
@@ -26,7 +27,8 @@
icon,
onConfirm,
onCancel,
onKeydown
onKeydown,
children
}: Props = $props();
function handleKeydown(event: KeyboardEvent) {
@@ -60,6 +62,10 @@
</AlertDialog.Description>
</AlertDialog.Header>
{#if children}
{@render children()}
{/if}
<AlertDialog.Footer>
<AlertDialog.Cancel onclick={onCancel}>{cancelText}</AlertDialog.Cancel>
<AlertDialog.Action
@@ -1,3 +1,2 @@
export const AUTO_SCROLL_INTERVAL = 100;
export const INITIAL_SCROLL_DELAY = 50;
export const AUTO_SCROLL_AT_BOTTOM_THRESHOLD = 10;
@@ -0,0 +1,3 @@
export const CONTEXT_KEY_MESSAGE_EDIT = 'chat-message-edit';
export const CONTEXT_KEY_CHAT_ACTIONS = 'chat-actions';
export const CONTEXT_KEY_CHAT_SETTINGS_DIALOG = 'chat-settings-dialog';
@@ -10,6 +10,7 @@ export * from './cache';
export * from './chat-form';
export * from './code-blocks';
export * from './code';
export * from './context-keys';
export * from './css-classes';
export * from './favicon';
export * from './floating-ui-constraints';
@@ -1 +1,2 @@
export const FORK_TREE_DEPTH_PADDING = 8;
export const SYSTEM_MESSAGE_PLACEHOLDER = 'System message';
@@ -1,4 +1,5 @@
import { getContext, setContext } from 'svelte';
import { CONTEXT_KEY_CHAT_ACTIONS } from '$lib/constants';
export interface ChatActionsContext {
copy: (message: DatabaseMessage) => void;
@@ -21,9 +22,13 @@ export interface ChatActionsContext {
) => void;
regenerateWithBranching: (message: DatabaseMessage, modelOverride?: string) => void;
continueAssistantMessage: (message: DatabaseMessage) => void;
forkConversation: (
message: DatabaseMessage,
options: { name: string; includeAttachments: boolean }
) => void;
}
const CHAT_ACTIONS_KEY = Symbol.for('chat-actions');
const CHAT_ACTIONS_KEY = Symbol.for(CONTEXT_KEY_CHAT_ACTIONS);
export function setChatActionsContext(ctx: ChatActionsContext): ChatActionsContext {
return setContext(CHAT_ACTIONS_KEY, ctx);
@@ -0,0 +1,19 @@
import { getContext, setContext } from 'svelte';
import type { SettingsSectionTitle } from '$lib/constants';
import { CONTEXT_KEY_CHAT_SETTINGS_DIALOG } from '$lib/constants';
export interface ChatSettingsDialogContext {
open: (initialSection?: SettingsSectionTitle) => void;
}
const CHAT_SETTINGS_DIALOG_KEY = Symbol.for(CONTEXT_KEY_CHAT_SETTINGS_DIALOG);
export function setChatSettingsDialogContext(
ctx: ChatSettingsDialogContext
): ChatSettingsDialogContext {
return setContext(CHAT_SETTINGS_DIALOG_KEY, ctx);
}
export function getChatSettingsDialogContext(): ChatSettingsDialogContext {
return getContext(CHAT_SETTINGS_DIALOG_KEY);
}
@@ -11,3 +11,9 @@ export {
setChatActionsContext,
type ChatActionsContext
} from './chat-actions.context';
export {
getChatSettingsDialogContext,
setChatSettingsDialogContext,
type ChatSettingsDialogContext
} from './chat-settings-dialog.context';
@@ -1,4 +1,5 @@
import { getContext, setContext } from 'svelte';
import { CONTEXT_KEY_MESSAGE_EDIT } from '$lib/constants';
export interface MessageEditState {
readonly isEditing: boolean;
@@ -22,7 +23,7 @@ export interface MessageEditActions {
export type MessageEditContext = MessageEditState & MessageEditActions;
const MESSAGE_EDIT_KEY = Symbol.for('chat-message-edit');
const MESSAGE_EDIT_KEY = Symbol.for(CONTEXT_KEY_MESSAGE_EDIT);
/**
* Sets the message edit context. Call this in the parent component (ChatMessage.svelte).
@@ -1,8 +1,8 @@
import { AUTO_SCROLL_AT_BOTTOM_THRESHOLD, AUTO_SCROLL_INTERVAL } from '$lib/constants';
export interface AutoScrollOptions {
/** Whether auto-scroll is disabled globally (e.g., from settings) */
disabled?: boolean;
isColumnReverse?: boolean;
}
/**
@@ -12,6 +12,7 @@ export interface AutoScrollOptions {
* - Auto-scrolls to bottom during streaming/loading
* - Stops auto-scroll when user manually scrolls up
* - Resumes auto-scroll when user scrolls back to bottom
* - Supports both normal and column-reverse scroll containers
*/
export class AutoScrollController {
private _autoScrollEnabled = $state(true);
@@ -21,9 +22,14 @@ export class AutoScrollController {
private _scrollTimeout: ReturnType<typeof setTimeout> | undefined;
private _container: HTMLElement | undefined;
private _disabled: boolean;
private _isColumnReverse: boolean;
private _mutationObserver: MutationObserver | null = null;
private _rafPending = false;
private _observerEnabled = false;
constructor(options: AutoScrollOptions = {}) {
this._disabled = options.disabled ?? false;
this._isColumnReverse = options.isColumnReverse ?? false;
}
get autoScrollEnabled(): boolean {
@@ -38,7 +44,12 @@ export class AutoScrollController {
* Binds the controller to a scrollable container element.
*/
setContainer(container: HTMLElement | undefined): void {
this._doStopObserving();
this._container = container;
if (this._observerEnabled && container && !this._disabled) {
this._doStartObserving();
}
}
/**
@@ -49,6 +60,9 @@ export class AutoScrollController {
if (disabled) {
this._autoScrollEnabled = false;
this.stopInterval();
this._doStopObserving();
} else if (this._observerEnabled && this._container && !this._mutationObserver) {
this._doStartObserving();
}
}
@@ -59,10 +73,23 @@ export class AutoScrollController {
if (this._disabled || !this._container) return;
const { scrollTop, scrollHeight, clientHeight } = this._container;
const distanceFromBottom = scrollHeight - scrollTop - clientHeight;
let distanceFromBottom: number;
let isScrollingUp: boolean;
if (this._isColumnReverse) {
// column-reverse: scrollTop=0 at bottom, negative when scrolled up
distanceFromBottom = Math.abs(scrollTop);
isScrollingUp = scrollTop < this._lastScrollTop;
} else {
// normal: scrollTop=0 at top, increases when scrolled down
distanceFromBottom = scrollHeight - clientHeight - scrollTop;
isScrollingUp = scrollTop < this._lastScrollTop;
}
const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD;
if (scrollTop < this._lastScrollTop && !isAtBottom) {
if (isScrollingUp && !isAtBottom) {
this._userScrolledUp = true;
this._autoScrollEnabled = false;
} else if (isAtBottom && this._userScrolledUp) {
@@ -90,10 +117,12 @@ export class AutoScrollController {
scrollToBottom(behavior: ScrollBehavior = 'smooth'): void {
if (this._disabled || !this._container) return;
this._container.scrollTo({
top: this._container.scrollHeight,
behavior
});
if (this._isColumnReverse) {
// column-reverse: scrollTop=0 is the bottom
this._container.scrollTo({ top: 0, behavior });
} else {
this._container.scrollTo({ top: this._container.scrollHeight, behavior });
}
}
/**
@@ -150,11 +179,69 @@ export class AutoScrollController {
*/
destroy(): void {
this.stopInterval();
this._doStopObserving();
if (this._scrollTimeout) {
clearTimeout(this._scrollTimeout);
this._scrollTimeout = undefined;
}
}
/**
* Starts a MutationObserver on the container that auto-scrolls to bottom
* on content changes. More responsive than interval-based polling.
*/
startObserving(): void {
this._observerEnabled = true;
if (this._container && !this._disabled && !this._mutationObserver) {
this._doStartObserving();
}
}
/**
* Stops the MutationObserver.
*/
stopObserving(): void {
this._observerEnabled = false;
this._doStopObserving();
}
private _doStartObserving(): void {
if (!this._container || this._mutationObserver) return;
const isReverse = this._isColumnReverse;
this._mutationObserver = new MutationObserver(() => {
if (!this._autoScrollEnabled || this._rafPending) return;
this._rafPending = true;
requestAnimationFrame(() => {
this._rafPending = false;
if (this._autoScrollEnabled && this._container) {
if (isReverse) {
// column-reverse: scrollTop=0 is the bottom
this._container.scrollTop = 0;
} else {
this._container.scrollTop = this._container.scrollHeight;
}
}
});
});
this._mutationObserver.observe(this._container, {
childList: true,
subtree: true,
characterData: true
});
}
private _doStopObserving(): void {
if (this._mutationObserver) {
this._mutationObserver.disconnect();
this._mutationObserver = null;
}
this._rafPending = false;
}
}
/**
@@ -1,5 +1,6 @@
import Dexie, { type EntityTable } from 'dexie';
import { findDescendantMessages, uuid } from '$lib/utils';
import { findDescendantMessages, uuid, filterByLeafNodeId } from '$lib/utils';
import type { McpServerOverride } from '$lib/types/database';
class LlamacppDatabase extends Dexie {
conversations!: EntityTable<DatabaseConversation, string>;
@@ -173,8 +174,47 @@ export class DatabaseService {
*
* @param id - Conversation ID
*/
static async deleteConversation(id: string): Promise<void> {
static async deleteConversation(
id: string,
options?: { deleteWithForks?: boolean }
): Promise<void> {
await db.transaction('rw', [db.conversations, db.messages], async () => {
if (options?.deleteWithForks) {
// Recursively collect all descendant IDs
const idsToDelete: string[] = [];
const queue = [id];
while (queue.length > 0) {
const parentId = queue.pop()!;
const children = await db.conversations
.filter((c) => c.forkedFromConversationId === parentId)
.toArray();
for (const child of children) {
idsToDelete.push(child.id);
queue.push(child.id);
}
}
for (const forkId of idsToDelete) {
await db.conversations.delete(forkId);
await db.messages.where('convId').equals(forkId).delete();
}
} else {
// Reparent direct children to deleted conv's parent
const conv = await db.conversations.get(id);
const newParent = conv?.forkedFromConversationId;
const directChildren = await db.conversations
.filter((c) => c.forkedFromConversationId === id)
.toArray();
for (const child of directChildren) {
await db.conversations.update(child.id, {
forkedFromConversationId: newParent ?? undefined
});
}
}
await db.conversations.delete(id);
await db.messages.where('convId').equals(id).delete();
});
@@ -364,4 +404,88 @@ export class DatabaseService {
return { imported: importedCount, skipped: skippedCount };
});
}
/**
*
*
* Forking
*
*
*/
/**
* Forks a conversation at a specific message, creating a new conversation
* containing all messages from the root up to (and including) the target message.
*
* @param sourceConvId - The source conversation ID
* @param atMessageId - The message ID to fork at (the new conversation ends here)
* @param options - Fork options (name and whether to include attachments)
* @returns The newly created conversation
*/
static async forkConversation(
sourceConvId: string,
atMessageId: string,
options: { name: string; includeAttachments: boolean }
): Promise<DatabaseConversation> {
return await db.transaction('rw', [db.conversations, db.messages], async () => {
const sourceConv = await db.conversations.get(sourceConvId);
if (!sourceConv) {
throw new Error(`Source conversation ${sourceConvId} not found`);
}
const allMessages = await db.messages.where('convId').equals(sourceConvId).toArray();
const pathMessages = filterByLeafNodeId(allMessages, atMessageId, true) as DatabaseMessage[];
if (pathMessages.length === 0) {
throw new Error(`Could not resolve message path to ${atMessageId}`);
}
const idMap = new Map<string, string>();
for (const msg of pathMessages) {
idMap.set(msg.id, uuid());
}
const newConvId = uuid();
const clonedMessages: DatabaseMessage[] = pathMessages.map((msg) => {
const newId = idMap.get(msg.id)!;
const newParent = msg.parent ? (idMap.get(msg.parent) ?? null) : null;
const newChildren = msg.children
.filter((childId: string) => idMap.has(childId))
.map((childId: string) => idMap.get(childId)!);
return {
...msg,
id: newId,
convId: newConvId,
parent: newParent,
children: newChildren,
extra: options.includeAttachments ? msg.extra : undefined
};
});
const lastClonedMessage = clonedMessages[clonedMessages.length - 1];
const newConv: DatabaseConversation = {
id: newConvId,
name: options.name,
lastModified: Date.now(),
currNode: lastClonedMessage.id,
forkedFromConversationId: sourceConvId,
mcpServerOverrides: sourceConv.mcpServerOverrides
? sourceConv.mcpServerOverrides.map((o: McpServerOverride) => ({
serverId: o.serverId,
enabled: o.enabled
}))
: undefined
};
await db.conversations.add(newConv);
for (const msg of clonedMessages) {
await db.messages.add(msg);
}
return newConv;
});
}
}
@@ -1265,35 +1265,53 @@ class ChatStore {
let result = this.getMessageByIdWithRole(messageId, MessageRole.USER);
if (!result) result = this.getMessageByIdWithRole(messageId, MessageRole.SYSTEM);
if (!result) return;
const { message: msg } = result;
const { message: msg, index: idx } = result;
try {
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
const isFirstUserMessage =
msg.role === MessageRole.USER && rootMessage && msg.parent === rootMessage.id;
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
const extrasToUse =
newExtras !== undefined
? JSON.parse(JSON.stringify(newExtras))
: msg.extra
? JSON.parse(JSON.stringify(msg.extra))
: undefined;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
let messageIdForResponse: string;
if (msg.children.length === 0) {
// No responses after this message — update in place instead of branching
const updates: Partial<DatabaseMessage> = {
content: newContent,
toolCalls: msg.toolCalls || '',
children: [],
extra: extrasToUse,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
timestamp: Date.now(),
extra: extrasToUse
};
await DatabaseService.updateMessage(msg.id, updates);
conversationsStore.updateMessageAtIndex(idx, updates);
messageIdForResponse = msg.id;
} else {
// Has children — create a new branch as sibling
const parentId = msg.parent || rootMessage?.id;
if (!parentId) return;
const newMessage = await DatabaseService.createMessageBranch(
{
convId: msg.convId,
type: msg.type,
timestamp: Date.now(),
role: msg.role,
content: newContent,
toolCalls: msg.toolCalls || '',
children: [],
extra: extrasToUse,
model: msg.model
},
parentId
);
await conversationsStore.updateCurrentNode(newMessage.id);
messageIdForResponse = newMessage.id;
}
conversationsStore.updateConversationTimestamp();
if (isFirstUserMessage && newContent.trim())
await conversationsStore.updateConversationTitleWithConfirmation(
@@ -1301,7 +1319,8 @@ class ChatStore {
newContent.trim()
);
await conversationsStore.refreshActiveMessages();
if (msg.role === MessageRole.USER) await this.generateResponseForMessage(newMessage.id);
if (msg.role === MessageRole.USER)
await this.generateResponseForMessage(messageIdForResponse);
} catch (error) {
console.error('Failed to edit message with branching:', error);
}
@@ -39,6 +39,12 @@ import {
MULTIPLE_UNDERSCORE_REGEX,
MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY
} from '$lib/constants';
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
export interface ConversationTreeItem {
conversation: DatabaseConversation;
depth: number;
}
class ConversationsStore {
/**
@@ -300,15 +306,45 @@ class ConversationsStore {
* Deletes a conversation and all its messages
* @param convId - The conversation ID to delete
*/
async deleteConversation(convId: string): Promise<void> {
async deleteConversation(convId: string, options?: { deleteWithForks?: boolean }): Promise<void> {
try {
await DatabaseService.deleteConversation(convId);
await DatabaseService.deleteConversation(convId, options);
this.conversations = this.conversations.filter((c) => c.id !== convId);
if (options?.deleteWithForks) {
// Collect all descendants recursively
const idsToRemove = new SvelteSet([convId]);
const queue = [convId];
while (queue.length > 0) {
const parentId = queue.pop()!;
for (const c of this.conversations) {
if (c.forkedFromConversationId === parentId && !idsToRemove.has(c.id)) {
idsToRemove.add(c.id);
queue.push(c.id);
}
}
}
this.conversations = this.conversations.filter((c) => !idsToRemove.has(c.id));
if (this.activeConversation?.id === convId) {
this.clearActiveConversation();
await goto(`?new_chat=true#/`);
if (this.activeConversation && idsToRemove.has(this.activeConversation.id)) {
this.clearActiveConversation();
await goto(`?new_chat=true#/`);
}
} else {
// Reparent direct children to deleted conv's parent (or promote to top-level)
const deletedConv = this.conversations.find((c) => c.id === convId);
const newParent = deletedConv?.forkedFromConversationId;
this.conversations = this.conversations
.filter((c) => c.id !== convId)
.map((c) =>
c.forkedFromConversationId === convId
? { ...c, forkedFromConversationId: newParent }
: c
);
if (this.activeConversation?.id === convId) {
this.clearActiveConversation();
await goto(`?new_chat=true#/`);
}
}
} catch (error) {
console.error('Failed to delete conversation:', error);
@@ -658,6 +694,42 @@ class ConversationsStore {
this.saveMcpDefaults();
}
/**
* Forks a conversation at a specific message, creating a new conversation
* containing messages from root up to the target message, then navigates to it.
*
* @param messageId - The message ID to fork at
* @param options - Fork options (name and whether to include attachments)
* @returns The new conversation ID, or null if fork failed
*/
async forkConversation(
messageId: string,
options: { name: string; includeAttachments: boolean }
): Promise<string | null> {
if (!this.activeConversation) return null;
try {
const newConv = await DatabaseService.forkConversation(
this.activeConversation.id,
messageId,
options
);
this.conversations = [newConv, ...this.conversations];
await goto(`#/chat/${newConv.id}`);
toast.success('Conversation forked');
return newConv.id;
} catch (error) {
console.error('Failed to fork conversation:', error);
toast.error('Failed to fork conversation');
return null;
}
}
/**
*
*
@@ -830,3 +902,53 @@ export const conversations = () => conversationsStore.conversations;
export const activeConversation = () => conversationsStore.activeConversation;
export const activeMessages = () => conversationsStore.activeMessages;
export const isConversationsInitialized = () => conversationsStore.isInitialized;
/**
* Builds a flat tree of conversations with depth levels for nested forks.
* Accepts a pre-filtered list so search filtering stays in the component.
*/
export function buildConversationTree(convs: DatabaseConversation[]): ConversationTreeItem[] {
const childrenByParent = new SvelteMap<string, DatabaseConversation[]>();
const forkIds = new SvelteSet<string>();
for (const conv of convs) {
if (conv.forkedFromConversationId) {
forkIds.add(conv.id);
const siblings = childrenByParent.get(conv.forkedFromConversationId) || [];
siblings.push(conv);
childrenByParent.set(conv.forkedFromConversationId, siblings);
}
}
const result: ConversationTreeItem[] = [];
const visited = new SvelteSet<string>();
function walk(conv: DatabaseConversation, depth: number) {
visited.add(conv.id);
result.push({ conversation: conv, depth });
const children = childrenByParent.get(conv.id);
if (children) {
children.sort((a, b) => b.lastModified - a.lastModified);
for (const child of children) {
walk(child, depth + 1);
}
}
}
const roots = convs.filter((c) => !forkIds.has(c.id));
for (const root of roots) {
walk(root, 0);
}
for (const conv of convs) {
if (!visited.has(conv.id)) {
walk(conv, 1);
}
}
return result;
}
+1
View File
@@ -12,6 +12,7 @@ export interface DatabaseConversation {
lastModified: number;
name: string;
mcpServerOverrides?: McpServerOverride[];
forkedFromConversationId?: string;
}
export interface DatabaseMessageExtraAudioFile {
+23 -1
View File
@@ -4,7 +4,11 @@
import { browser } from '$app/environment';
import { page } from '$app/state';
import { untrack } from 'svelte';
import { ChatSidebar, DialogConversationTitleUpdate } from '$lib/components/app';
import {
ChatSidebar,
DialogConversationTitleUpdate,
DialogChatSettings
} from '$lib/components/app';
import { isLoading } from '$lib/stores/chat.svelte';
import { conversationsStore, activeMessages } from '$lib/stores/conversations.svelte';
import * as Sidebar from '$lib/components/ui/sidebar/index.js';
@@ -17,8 +21,10 @@
import { modelsStore } from '$lib/stores/models.svelte';
import { mcpStore } from '$lib/stores/mcp.svelte';
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
import type { SettingsSectionTitle } from '$lib/constants';
import { KeyboardKey } from '$lib/enums';
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
import { setChatSettingsDialogContext } from '$lib/contexts';
let { children } = $props();
@@ -42,6 +48,16 @@
let titleUpdateNewTitle = $state('');
let titleUpdateResolve: ((value: boolean) => void) | null = null;
let chatSettingsDialogOpen = $state(false);
let chatSettingsDialogInitialSection = $state<SettingsSectionTitle | undefined>(undefined);
setChatSettingsDialogContext({
open: (initialSection?: SettingsSectionTitle) => {
chatSettingsDialogInitialSection = initialSection;
chatSettingsDialogOpen = true;
}
});
// Global keyboard shortcuts
function handleKeydown(event: KeyboardEvent) {
const isCtrlOrCmd = event.ctrlKey || event.metaKey;
@@ -213,6 +229,12 @@
<Toaster richColors />
<DialogChatSettings
open={chatSettingsDialogOpen}
onOpenChange={(open) => (chatSettingsDialogOpen = open)}
initialSection={chatSettingsDialogInitialSection}
/>
<DialogConversationTitleUpdate
bind:open={titleUpdateDialogOpen}
currentTitle={titleUpdateCurrentTitle}
@@ -73,7 +73,7 @@
conversationsStore.conversations = mockConversations;
}, 0));
const searchTrigger = screen.getByText('Search conversations');
const searchTrigger = screen.getByText('Search');
userEvent.click(searchTrigger);
}}
>
+104 -41
View File
@@ -467,10 +467,6 @@ bool set_socket_opt_impl(socket_t sock, int level, int optname,
optlen) == 0;
}
bool set_socket_opt(socket_t sock, int level, int optname, int optval) {
return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval));
}
bool set_socket_opt_time(socket_t sock, int level, int optname,
time_t sec, time_t usec) {
#ifdef _WIN32
@@ -2218,7 +2214,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port,
#ifdef _WIN32
// Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so
// remove the option.
detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0);
set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0);
#endif
bool dummy;
@@ -4373,6 +4369,7 @@ make_multipart_content_provider(const UploadFormDataItems &items,
struct MultipartState {
std::vector<std::string> owned;
std::vector<MultipartSegment> segs;
std::vector<char> buf = std::vector<char>(CPPHTTPLIB_SEND_BUFSIZ);
};
auto state = std::make_shared<MultipartState>();
state->owned = std::move(owned);
@@ -4381,19 +4378,49 @@ make_multipart_content_provider(const UploadFormDataItems &items,
state->segs = std::move(segs);
return [state](size_t offset, size_t length, DataSink &sink) -> bool {
// Buffer multiple small segments into fewer, larger writes to avoid
// excessive TCP packets when there are many form data items (#2410)
auto &buf = state->buf;
auto buf_size = buf.size();
size_t buf_len = 0;
size_t remaining = length;
// Find the first segment containing 'offset'
size_t pos = 0;
for (const auto &seg : state->segs) {
// Loop invariant: pos <= offset (proven by advancing pos only when
// offset - pos >= seg.size, i.e., the segment doesn't contain offset)
if (seg.size > 0 && offset - pos < seg.size) {
size_t seg_offset = offset - pos;
size_t available = seg.size - seg_offset;
size_t to_write = (std::min)(available, length);
return sink.write(seg.data + seg_offset, to_write);
}
size_t seg_idx = 0;
for (; seg_idx < state->segs.size(); seg_idx++) {
const auto &seg = state->segs[seg_idx];
if (seg.size > 0 && offset - pos < seg.size) { break; }
pos += seg.size;
}
return true; // past end (shouldn't be reached when content_length is exact)
size_t seg_offset = (seg_idx < state->segs.size()) ? offset - pos : 0;
for (; seg_idx < state->segs.size() && remaining > 0; seg_idx++) {
const auto &seg = state->segs[seg_idx];
size_t available = seg.size - seg_offset;
size_t to_copy = (std::min)(available, remaining);
const char *src = seg.data + seg_offset;
seg_offset = 0; // only the first segment has a non-zero offset
while (to_copy > 0) {
size_t space = buf_size - buf_len;
size_t chunk = (std::min)(to_copy, space);
std::memcpy(buf.data() + buf_len, src, chunk);
buf_len += chunk;
src += chunk;
to_copy -= chunk;
remaining -= chunk;
if (buf_len == buf_size) {
if (!sink.write(buf.data(), buf_len)) { return false; }
buf_len = 0;
}
}
}
if (buf_len > 0) { return sink.write(buf.data(), buf_len); }
return true;
};
}
@@ -5264,13 +5291,18 @@ bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx,
*/
void default_socket_options(socket_t sock) {
detail::set_socket_opt(sock, SOL_SOCKET,
set_socket_opt(sock, SOL_SOCKET,
#ifdef SO_REUSEPORT
SO_REUSEPORT,
SO_REUSEPORT,
#else
SO_REUSEADDR,
SO_REUSEADDR,
#endif
1);
1);
}
bool set_socket_opt(socket_t sock, int level, int optname, int optval) {
return detail::set_socket_opt_impl(sock, level, optname, &optval,
sizeof(optval));
}
std::string get_bearer_token_auth(const Request &req) {
@@ -7418,6 +7450,8 @@ bool Server::read_content_core(
return false;
}
req.body_consumed_ = true;
if (req.is_multipart_form_data()) {
if (!multipart_form_data_parser.is_valid()) {
res.status = StatusCode::BadRequest_400;
@@ -7688,9 +7722,7 @@ bool Server::listen_internal() {
detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO,
write_timeout_sec_, write_timeout_usec_);
if (tcp_nodelay_) {
detail::set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1);
}
if (tcp_nodelay_) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); }
if (!task_queue->enqueue(
[this, sock]() { process_and_close_socket(sock); })) {
@@ -8036,8 +8068,19 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
return write_response(strm, close_connection, req, res);
}
// RFC 9112 §6.3: Reject requests with both a non-zero Content-Length and
// any Transfer-Encoding to prevent request smuggling. Content-Length: 0 is
// tolerated for compatibility with existing clients.
if (req.get_header_value_u64("Content-Length") > 0 &&
req.has_header("Transfer-Encoding")) {
connection_closed = true;
res.status = StatusCode::BadRequest_400;
return write_response(strm, close_connection, req, res);
}
// Check if the request URI doesn't exceed the limit
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
connection_closed = true;
res.status = StatusCode::UriTooLong_414;
output_error_log(Error::ExceedUriMaxLength, &req);
return write_response(strm, close_connection, req, res);
@@ -8066,6 +8109,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
if (req.has_header("Accept")) {
const auto &accept_header = req.get_header_value("Accept");
if (!detail::parse_accept_header(accept_header, req.accept_content_types)) {
connection_closed = true;
res.status = StatusCode::BadRequest_400;
output_error_log(Error::HTTPParsing, &req);
return write_response(strm, close_connection, req, res);
@@ -8075,6 +8119,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
if (req.has_header("Range")) {
const auto &range_header_value = req.get_header_value("Range");
if (!detail::parse_range_header(range_header_value, req.ranges)) {
connection_closed = true;
res.status = StatusCode::RangeNotSatisfiable_416;
output_error_log(Error::InvalidRangeHeader, &req);
return write_response(strm, close_connection, req, res);
@@ -8202,6 +8247,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
}
}
#endif
auto ret = false;
if (routed) {
if (res.status == -1) {
res.status = req.ranges.empty() ? StatusCode::OK_200
@@ -8209,6 +8255,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
}
// Serve file content by using a content provider
auto file_open_error = false;
if (!res.file_content_path_.empty()) {
const auto &path = res.file_content_path_;
auto mm = std::make_shared<detail::mmap>(path.c_str());
@@ -8218,37 +8265,53 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
res.content_provider_ = nullptr;
res.status = StatusCode::NotFound_404;
output_error_log(Error::OpenFile, &req);
return write_response(strm, close_connection, req, res);
}
file_open_error = true;
} else {
auto content_type = res.file_content_content_type_;
if (content_type.empty()) {
content_type = detail::find_content_type(
path, file_extension_and_mimetype_map_, default_file_mimetype_);
}
auto content_type = res.file_content_content_type_;
if (content_type.empty()) {
content_type = detail::find_content_type(
path, file_extension_and_mimetype_map_, default_file_mimetype_);
res.set_content_provider(
mm->size(), content_type,
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
sink.write(mm->data() + offset, length);
return true;
});
}
res.set_content_provider(
mm->size(), content_type,
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
sink.write(mm->data() + offset, length);
return true;
});
}
if (detail::range_error(req, res)) {
if (file_open_error) {
ret = write_response(strm, close_connection, req, res);
} else if (detail::range_error(req, res)) {
res.body.clear();
res.content_length_ = 0;
res.content_provider_ = nullptr;
res.status = StatusCode::RangeNotSatisfiable_416;
return write_response(strm, close_connection, req, res);
ret = write_response(strm, close_connection, req, res);
} else {
ret = write_response_with_content(strm, close_connection, req, res);
}
return write_response_with_content(strm, close_connection, req, res);
} else {
if (res.status == -1) { res.status = StatusCode::NotFound_404; }
return write_response(strm, close_connection, req, res);
ret = write_response(strm, close_connection, req, res);
}
// Drain any unconsumed request body to prevent request smuggling on
// keep-alive connections.
if (!req.body_consumed_ && detail::expect_content(req)) {
int drain_status = 200; // required by read_content signature
if (!detail::read_content(
strm, req, payload_max_length_, drain_status, nullptr,
[](const char *, size_t, size_t, size_t) { return true; }, false)) {
// Body exceeds payload limit or read error — close the connection
// to prevent leftover bytes from being misinterpreted.
connection_closed = true;
}
}
return ret;
}
bool Server::is_valid() const { return true; }
+12 -2
View File
@@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.39.0"
#define CPPHTTPLIB_VERSION_NUM "0x002700"
#define CPPHTTPLIB_VERSION "0.40.0"
#define CPPHTTPLIB_VERSION_NUM "0x002800"
#ifdef _WIN32
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
@@ -1266,6 +1266,7 @@ struct Request {
bool is_multipart_form_data() const;
// private members...
bool body_consumed_ = false;
size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT;
size_t content_length_ = 0;
ContentProvider content_provider_;
@@ -1475,6 +1476,8 @@ using SocketOptions = std::function<void(socket_t sock)>;
void default_socket_options(socket_t sock);
bool set_socket_opt(socket_t sock, int level, int optname, int optval);
const char *status_message(int status);
std::string to_string(Error error);
@@ -1564,6 +1567,13 @@ ssize_t write_headers(Stream &strm, const Headers &headers);
bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec,
time_t usec);
size_t get_multipart_content_length(const UploadFormDataItems &items,
const std::string &boundary);
ContentProvider
make_multipart_content_provider(const UploadFormDataItems &items,
const std::string &boundary);
} // namespace detail
class Server {