mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 18:17:42 +02:00
Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ead417f01c | |||
| 64ac9ab66a | |||
| cad2d3884c | |||
| 389c7d4955 | |||
| 278521c33a | |||
| e2eb39e81c | |||
| abf9a62161 | |||
| 7c203670f8 | |||
| ec16a072f0 | |||
| f5d1c4179f | |||
| 2405d59cb6 | |||
| afe65aa282 | |||
| 65097181e4 | |||
| 98ae0a0d36 | |||
| 3a14a542f5 | |||
| 968189729f | |||
| e397d3885c | |||
| e6f2ec01ff | |||
| edfb440a2f | |||
| 3d66da1809 | |||
| 82b703f8bc | |||
| 51a84efc53 | |||
| b0f0dd3e51 | |||
| 0eb4764182 | |||
| 1f5d15e665 | |||
| c46758d28f | |||
| bf934f28db | |||
| 5c1a7b8355 | |||
| 59d840209a |
@@ -1,11 +1,13 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
ARG TARGETARCH
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git cmake libssl-dev
|
||||
apt-get install -y gcc-14 g++-14 build-essential git cmake libssl-dev
|
||||
|
||||
ENV CC=gcc-14 CXX=g++-14
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -55,8 +57,9 @@ RUN apt-get update \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
&& pip install --upgrade pip setuptools wheel \
|
||||
&& pip install -r requirements.txt \
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -33,6 +33,23 @@ RUN mkdir -p /app/full \
|
||||
|
||||
FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS base
|
||||
|
||||
ARG IGC_VERSION=v2.30.1
|
||||
ARG IGC_VERSION_FULL=2_2.30.1+20950
|
||||
ARG COMPUTE_RUNTIME_VERSION=26.09.37435.1
|
||||
ARG COMPUTE_RUNTIME_VERSION_FULL=26.09.37435.1-0
|
||||
ARG IGDGMM_VERSION=22.9.0
|
||||
RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
|
||||
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \
|
||||
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-opencl-${IGC_VERSION_FULL}_amd64.deb \
|
||||
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-ocloc-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \
|
||||
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-ocloc_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \
|
||||
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-opencl-icd-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \
|
||||
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/intel-opencl-icd_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \
|
||||
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/libigdgmm12_${IGDGMM_VERSION}_amd64.deb \
|
||||
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/libze-intel-gpu1-dbgsym_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.ddeb \
|
||||
&& wget https://github.com/intel/compute-runtime/releases/download/$COMPUTE_RUNTIME_VERSION/libze-intel-gpu1_${COMPUTE_RUNTIME_VERSION_FULL}_amd64.deb \
|
||||
&& dpkg --install *.deb
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt autoremove -y \
|
||||
|
||||
@@ -36,18 +36,16 @@ jobs:
|
||||
matrix:
|
||||
config:
|
||||
# Multi-stage build
|
||||
# Note: the arm64 images are failing, which prevents the amd64 images from being built
|
||||
# https://github.com/ggml-org/llama.cpp/issues/11888
|
||||
#- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
|
||||
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
|
||||
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04-s390x" }
|
||||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
|
||||
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
|
||||
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04-s390x" }
|
||||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v6
|
||||
@@ -58,7 +56,7 @@ jobs:
|
||||
if: ${{ matrix.config.tag != 's390x' }}
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
|
||||
with:
|
||||
image: tonistiigi/binfmt:qemu-v7.0.0-28
|
||||
image: tonistiigi/binfmt:qemu-v10.2.1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
pip-install: -r requirements/requirements-all.txt ty==0.0.24
|
||||
pip-install: -r requirements/requirements-all.txt ty==0.0.26
|
||||
# - name: Type-check with Pyright
|
||||
# uses: jakebailey/pyright-action@v2
|
||||
# with:
|
||||
|
||||
@@ -2807,6 +2807,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.port = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
|
||||
add_opt(common_arg(
|
||||
{"--reuse-port"},
|
||||
string_format("allow multiple sockets to bind to the same port (default: %s)", params.reuse_port ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.reuse_port = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_REUSE_PORT"));
|
||||
add_opt(common_arg(
|
||||
{"--path"}, "PATH",
|
||||
string_format("path to serve static files from (default: %s)", params.public_path.c_str()),
|
||||
|
||||
@@ -65,7 +65,7 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
@@ -221,7 +221,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & schema = func.at("parameters");
|
||||
const auto & schema = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
|
||||
// Build call_id parser based on position (if supported)
|
||||
common_peg_parser call_id_section = p.eps();
|
||||
@@ -282,19 +282,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
common_peg_parser tool_choice = p.choice();
|
||||
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & params = func.at("parameters");
|
||||
|
||||
if (!params.contains("properties") || !params.at("properties").is_object()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const auto & properties = params.at("properties");
|
||||
const auto & func = tool.at("function");
|
||||
std::string name = func.at("name");
|
||||
const auto & params = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
const auto & properties = params.contains("properties") ? params.at("properties") : json::object();
|
||||
std::set<std::string> required;
|
||||
if (params.contains("required") && params.at("required").is_array()) {
|
||||
params.at("required").get_to(required);
|
||||
}
|
||||
|
||||
// Build parser for each argument, separating required and optional
|
||||
std::vector<common_peg_parser> required_parsers;
|
||||
@@ -311,17 +303,18 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
}
|
||||
}
|
||||
|
||||
auto arg = p.tool_arg(
|
||||
p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(type == "string" ? p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.space()) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
auto arg =
|
||||
p.tool_arg(p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) +
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(type == "string" ?
|
||||
p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix),
|
||||
"tool-" + name + "-arg-" + param_name + "-schema",
|
||||
param_schema, true)) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.space()) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
|
||||
@@ -287,7 +287,7 @@ void analyze_reasoning::compare_reasoning_presence() {
|
||||
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest());
|
||||
});
|
||||
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
});
|
||||
// try the more aggressive parse first, if it fails, fall back to the delimiter one
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
@@ -297,7 +297,7 @@ void analyze_reasoning::compare_reasoning_presence() {
|
||||
if (result.result.success()) {
|
||||
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
start = trim_whitespace(result.tags["pre"]);
|
||||
start = trim_leading_whitespace(result.tags["pre"]);
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else if (!result.tags["post"].empty()) {
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
@@ -333,7 +333,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (left_trimmed.empty() && !diff.right.empty()) {
|
||||
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
|
||||
if (start.empty()) {
|
||||
start = right_trimmed;
|
||||
start = trim_leading_whitespace(diff.right);
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -344,7 +344,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
|
||||
start = seg[seg.size() - 2].value;
|
||||
}
|
||||
end = left_trimmed;
|
||||
end = trim_trailing_whitespace(diff.left);
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -363,15 +363,23 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
size_t len = std::min(base.size(), anchor_len);
|
||||
std::string anchor = base.substr(base.size() - len);
|
||||
auto pos = extended.rfind(anchor);
|
||||
if (pos == std::string::npos || pos + len >= extended.size()) continue;
|
||||
if (pos == std::string::npos || pos + len >= extended.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string extra = trim_whitespace(extended.substr(pos + len));
|
||||
if (extra.empty()) continue;
|
||||
if (extra.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto seg = prune_whitespace_segments(segmentize_markers(extra));
|
||||
if (seg.size() == 2 && seg[0].type == segment_type::MARKER && seg[1].type == segment_type::MARKER) {
|
||||
if (start.empty()) start = seg[0].value;
|
||||
if (end.empty()) end = seg[1].value;
|
||||
if (start.empty()) {
|
||||
start = seg[0].value;
|
||||
}
|
||||
if (end.empty()) {
|
||||
end = seg[1].value;
|
||||
}
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
break;
|
||||
}
|
||||
@@ -423,7 +431,7 @@ void analyze_reasoning::compare_reasoning_scope() {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Detected TOOLS_ONLY reasoning mode\n" ANSI_RESET, __func__);
|
||||
|
||||
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space()));
|
||||
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space()));
|
||||
});
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
@@ -516,7 +524,7 @@ analyze_content::analyze_content(const common_chat_template & tmpl, const analyz
|
||||
// Take the more promising diff
|
||||
std::string pure_content = rdiff.length() > diff_tools.left.length() ? rdiff : diff_tools.left;
|
||||
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.marker()) + p.space() + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
return p.tag("pre", p.marker() + p.space()) + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
});
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(pure_content);
|
||||
start = result.tags["pre"];
|
||||
|
||||
+8
-1
@@ -971,6 +971,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto start = p.rule("start", p.literal("<|start|>assistant"));
|
||||
@@ -979,7 +980,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
auto channel = p.literal("<|channel|>") + (p.literal("commentary") | p.literal("analysis"));
|
||||
auto constrain_type = p.chars("[A-Za-z0-9_-]", 1, -1);
|
||||
|
||||
auto analysis = p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
|
||||
if (extract_reasoning) {
|
||||
p.rule("analysis", p.literal("<|channel|>analysis<|message|>") + p.reasoning(content) + end);
|
||||
} else {
|
||||
p.rule("analysis", p.content(p.literal("<|channel|>analysis<|message|>") + content + end));
|
||||
}
|
||||
|
||||
auto analysis = p.ref("analysis");
|
||||
auto preamble = p.rule("preamble", p.literal("<|channel|>commentary<|message|>") + p.content(content) + end);
|
||||
auto final_msg = p.rule("final", p.literal("<|channel|>final<|message|>") + p.content(content));
|
||||
auto any = p.rule("any", preamble | analysis);
|
||||
|
||||
@@ -656,6 +656,97 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline bool glob_class_match(const char c, const char * pattern, const char * class_end) {
|
||||
const char * class_start = pattern;
|
||||
bool negated = false;
|
||||
|
||||
if (*class_start == '!') {
|
||||
negated = true;
|
||||
class_start++;
|
||||
}
|
||||
|
||||
// If first character after negation is ']' or '-', treat it as literal
|
||||
if (*class_start == ']' || *class_start == '-') {
|
||||
if (class_start < class_end && *class_start == c) {
|
||||
return !negated;
|
||||
}
|
||||
class_start++;
|
||||
}
|
||||
|
||||
bool matched = false;
|
||||
|
||||
while (class_start < class_end) {
|
||||
if (class_start + 2 < class_end && class_start[1] == '-' && class_start[2] != ']') {
|
||||
char start_char = *class_start;
|
||||
char end_char = class_start[2];
|
||||
if (c >= start_char && c <= end_char) {
|
||||
matched = true;
|
||||
break;
|
||||
}
|
||||
class_start += 3;
|
||||
} else {
|
||||
if (*class_start == c) {
|
||||
matched = true;
|
||||
break;
|
||||
}
|
||||
class_start++;
|
||||
}
|
||||
}
|
||||
|
||||
return negated ? !matched : matched;
|
||||
}
|
||||
|
||||
// simple glob: * matches non-/ chars, ** matches anything including /, [] matches character class
|
||||
static inline bool glob_match(const char * pattern, const char * str) {
|
||||
if (*pattern == '\0') {
|
||||
return *str == '\0';
|
||||
}
|
||||
if (pattern[0] == '*' && pattern[1] == '*') {
|
||||
const char * p = pattern + 2;
|
||||
if (glob_match(p, str)) return true;
|
||||
if (*str != '\0') return glob_match(pattern, str + 1);
|
||||
return false;
|
||||
}
|
||||
if (*pattern == '*') {
|
||||
const char * p = pattern + 1;
|
||||
for (; *str != '\0' && *str != '/'; str++) {
|
||||
if (glob_match(p, str)) return true;
|
||||
}
|
||||
return glob_match(p, str);
|
||||
}
|
||||
if (*pattern == '?' && *str != '\0' && *str != '/') {
|
||||
return glob_match(pattern + 1, str + 1);
|
||||
}
|
||||
if (*pattern == '[') {
|
||||
const char * class_end = pattern + 1;
|
||||
// If first character after '[' is ']' or '-', treat it as literal
|
||||
if (*class_end == ']' || *class_end == '-') {
|
||||
class_end++;
|
||||
}
|
||||
while (*class_end != '\0' && *class_end != ']') {
|
||||
class_end++;
|
||||
}
|
||||
if (*class_end == ']') {
|
||||
if (*str == '\0') return false;
|
||||
bool matched = glob_class_match(*str, pattern + 1, class_end);
|
||||
return matched && glob_match(class_end + 1, str + 1);
|
||||
} else {
|
||||
if (*str == '[') {
|
||||
return glob_match(pattern + 1, str + 1);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (*pattern == *str) {
|
||||
return glob_match(pattern + 1, str + 1);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool glob_match(const std::string & pattern, const std::string & str) {
|
||||
return glob_match(pattern.c_str(), str.c_str());
|
||||
}
|
||||
|
||||
//
|
||||
// Filesystem utils
|
||||
//
|
||||
|
||||
@@ -573,6 +573,7 @@ struct common_params {
|
||||
|
||||
// server params
|
||||
int32_t port = 8080; // server listens on this network port
|
||||
bool reuse_port = false; // allow multiple sockets to bind to the same port
|
||||
int32_t timeout_read = 600; // http read timeout in seconds
|
||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
@@ -793,6 +794,8 @@ std::string string_from(const std::vector<int> & values);
|
||||
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
|
||||
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
|
||||
|
||||
bool glob_match(const std::string & pattern, const std::string & str);
|
||||
|
||||
//
|
||||
// Filesystem utils
|
||||
//
|
||||
|
||||
@@ -539,6 +539,9 @@ private:
|
||||
statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr;
|
||||
return mk_stmt<slice_expression>(start_pos, std::move(start), std::move(stop), std::move(step));
|
||||
}
|
||||
if (slices.empty()) {
|
||||
return mk_stmt<blank_expression>(start_pos);
|
||||
}
|
||||
return std::move(slices[0]);
|
||||
}
|
||||
|
||||
|
||||
@@ -771,10 +771,15 @@ value member_expression::execute_impl(context & ctx) {
|
||||
}
|
||||
|
||||
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
|
||||
ensure_key_type_allowed(property);
|
||||
|
||||
value val = mk_val<value_undefined>("object_property");
|
||||
|
||||
if (property->is_undefined()) {
|
||||
JJ_DEBUG("%s", "Member expression property is undefined, returning undefined");
|
||||
return val;
|
||||
}
|
||||
|
||||
ensure_key_type_allowed(property);
|
||||
|
||||
if (is_val<value_undefined>(object)) {
|
||||
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
|
||||
return val;
|
||||
|
||||
@@ -263,6 +263,14 @@ struct comment_statement : public statement {
|
||||
|
||||
// Expressions
|
||||
|
||||
// Represents an omitted expression in a computed member, e.g. `a[]`.
|
||||
struct blank_expression : public expression {
|
||||
std::string type() const override { return "BlankExpression"; }
|
||||
value execute_impl(context &) override {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
};
|
||||
|
||||
struct member_expression : public expression {
|
||||
statement_ptr object;
|
||||
statement_ptr property;
|
||||
|
||||
@@ -416,15 +416,30 @@ private:
|
||||
i++;
|
||||
} else if (c == '(') {
|
||||
i++;
|
||||
if (i < length) {
|
||||
if (sub_pattern[i] == '?') {
|
||||
if (i < length && sub_pattern[i] == '?') {
|
||||
if (i + 1 < length && sub_pattern[i + 1] == ':') {
|
||||
i += 2; // skip "?:" for non-capturing group, treat as regular group
|
||||
} else {
|
||||
// lookahead/lookbehind (?=, ?!, ?<=, ?<!) - not supported
|
||||
_warnings.push_back("Unsupported pattern syntax");
|
||||
// skip to matching ')' to avoid UB on empty seq
|
||||
int depth = 1;
|
||||
while (i < length && depth > 0) {
|
||||
if (sub_pattern[i] == '\\' && i + 1 < length) {
|
||||
i += 2; // skip escaped character
|
||||
} else {
|
||||
if (sub_pattern[i] == '(') depth++;
|
||||
else if (sub_pattern[i] == ')') depth--;
|
||||
i++;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
seq.emplace_back("(" + to_rule(transform()) + ")", false);
|
||||
} else if (c == ')') {
|
||||
i++;
|
||||
if (start > 0 && sub_pattern[start - 1] != '(') {
|
||||
if (start > 0 && sub_pattern[start - 1] != '(' && (start < 2 || sub_pattern[start - 2] != '?' || sub_pattern[start - 1] != ':')) {
|
||||
_errors.push_back("Unbalanced parentheses");
|
||||
}
|
||||
return join_seq();
|
||||
|
||||
+12
-11
@@ -115,9 +115,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
break;
|
||||
}
|
||||
case REASONING_BUDGET_FORCING:
|
||||
// force_pos is advanced in apply(), not here.
|
||||
// This ensures the first forced token isn't skipped when the sampler
|
||||
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
break;
|
||||
@@ -144,14 +146,6 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// advance to next forced token (done here rather than in accept so that
|
||||
// the first forced token isn't skipped when starting in FORCING state)
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
}
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
||||
@@ -261,3 +255,10 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
common_reasoning_budget_state initial_state) {
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
|
||||
if (!smpl) {
|
||||
return REASONING_BUDGET_IDLE;
|
||||
}
|
||||
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
|
||||
}
|
||||
|
||||
@@ -51,3 +51,5 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state);
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
|
||||
|
||||
+46
-10
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
@@ -109,6 +110,7 @@ struct common_sampler {
|
||||
common_params_sampling params;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
struct llama_sampler * rbudget;
|
||||
struct llama_sampler * chain;
|
||||
|
||||
ring_buffer<llama_token> prev;
|
||||
@@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
llama_sampler * grmr = nullptr;
|
||||
llama_sampler * rbudget = nullptr;
|
||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
||||
|
||||
std::vector<llama_sampler *> samplers;
|
||||
@@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr) {
|
||||
if (grmr && !params.grammar_lazy) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
@@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
// reasoning budget sampler
|
||||
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
|
||||
rbudget = common_reasoning_budget_init(
|
||||
vocab,
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
prefill_tokens));
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
|
||||
prefill_tokens);
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
@@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .rbudget = */ rbudget,
|
||||
/* .chain = */ chain,
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
@@ -398,11 +402,27 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
}
|
||||
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
llama_sampler_free(gsmpl->rbudget);
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
delete gsmpl;
|
||||
}
|
||||
|
||||
static bool grammar_should_apply(struct common_sampler * gsmpl) {
|
||||
if (!gsmpl->grmr) {
|
||||
return false;
|
||||
}
|
||||
if (!gsmpl->rbudget) {
|
||||
return true;
|
||||
}
|
||||
if (gsmpl->params.grammar_lazy) {
|
||||
// if grammar is lazy, only apply when reasoning budget is not active
|
||||
const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
|
||||
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
@@ -410,6 +430,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
|
||||
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
|
||||
|
||||
llama_sampler_accept(gsmpl->rbudget, token);
|
||||
|
||||
if (gsmpl->grmr && accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
@@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
return new common_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .rbudget = */ llama_sampler_clone(gsmpl->rbudget),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
@@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
llama_token id = LLAMA_TOKEN_NULL;
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & rbudget = gsmpl->rbudget;
|
||||
auto & chain = gsmpl->chain;
|
||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||
|
||||
@@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
|
||||
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
|
||||
|
||||
// TODO: simplify
|
||||
gsmpl->cur.resize(1);
|
||||
@@ -524,7 +552,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
if (grammar_first) {
|
||||
// apply reasoning budget first
|
||||
llama_sampler_apply(rbudget, &cur_p);
|
||||
|
||||
if (grammar_first && grammar_should_apply(gsmpl)) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
@@ -532,7 +563,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
|
||||
id = cur_p.data[cur_p.selected].id;
|
||||
|
||||
if (grammar_first) {
|
||||
if (grammar_first || !grammar_should_apply(gsmpl)) {
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
llama_sampler_apply(rbudget, &cur_p);
|
||||
|
||||
if (grammar_should_apply(gsmpl)) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
|
||||
@@ -31,10 +31,10 @@ import gguf
|
||||
from gguf.vocab import MistralTokenizerType, MistralVocab
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found, ty:unresolved-import]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found, ty:unresolved-import]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found, ty:unresolved-import]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found, ty:unresolved-import]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
|
||||
# Add utils directory to path for direct script execution
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
|
||||
from common import get_model_name_from_env_path, compare_tokens, exit_with_warning # type: ignore[import-not-found]
|
||||
from common import get_model_name_from_env_path, compare_tokens, exit_with_warning # type: ignore[import-not-found, ty:unresolved-import]
|
||||
|
||||
def quick_logits_check(pytorch_file, llamacpp_file):
|
||||
"""Lightweight sanity check before NMSE"""
|
||||
|
||||
@@ -5,7 +5,7 @@ import sys
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from common import get_model_name_from_env_path # type: ignore[import-not-found]
|
||||
from common import get_model_name_from_env_path # type: ignore[import-not-found, ty:unresolved-import]
|
||||
|
||||
def calculate_nmse(reference, test):
|
||||
mse = np.mean((test - reference) ** 2)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from common import compare_tokens # type: ignore[import-not-found]
|
||||
from common import compare_tokens # type: ignore[import-not-found, ty:unresolved-import]
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
|
||||
@@ -7,7 +7,7 @@ import importlib
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
|
||||
from common import compare_tokens, exit_with_warning # type: ignore[import-not-found]
|
||||
from common import compare_tokens, exit_with_warning # type: ignore[import-not-found, ty:unresolved-import]
|
||||
|
||||
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
|
||||
|
||||
|
||||
@@ -20,4 +20,4 @@ cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA
|
||||
#cmake --build . --config Release --target llama-bench
|
||||
|
||||
#build all binary
|
||||
cmake --build . --config Release -j -v
|
||||
cmake --build . --config Release -j$((($(nproc)+1)/2)) -v
|
||||
|
||||
@@ -23,9 +23,9 @@ if [ $# -gt 0 ]; then
|
||||
GGML_SYCL_DEVICE=$1
|
||||
echo "use $GGML_SYCL_DEVICE as main GPU"
|
||||
#use signle GPU only
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE}
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE}
|
||||
|
||||
else
|
||||
#use multiple GPUs with same max compute units
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE}
|
||||
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE}
|
||||
fi
|
||||
|
||||
@@ -47,9 +47,11 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
#ifdef STRIDED_ITERATOR_AVAILABLE
|
||||
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
|
||||
#else
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
// offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size
|
||||
const int nrows_offset = nrows + 1;
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows_offset);
|
||||
int * offset_iterator = offsets_alloc.get();
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
const dim3 offset_grid((nrows_offset + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
|
||||
#endif
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
@@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
|
||||
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
|
||||
if (ggml_is_quantized(src0->type)) {
|
||||
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
|
||||
const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc);
|
||||
if (ne2 <= mmvq_mmid_max) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
@@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
}
|
||||
|
||||
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
|
||||
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
|
||||
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
|
||||
use_cuda_graph = false;
|
||||
if (node->op == GGML_OP_MUL_MAT_ID) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
|
||||
if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) {
|
||||
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
|
||||
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
|
||||
use_cuda_graph = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
|
||||
+342
-51
@@ -97,6 +97,194 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
||||
return MMVQ_PARAMETERS_GENERIC;
|
||||
}
|
||||
|
||||
// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID.
|
||||
// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE.
|
||||
// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details
|
||||
|
||||
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ1_S: return 6;
|
||||
case GGML_TYPE_IQ1_M: return 6;
|
||||
case GGML_TYPE_IQ2_S: return 4;
|
||||
case GGML_TYPE_IQ2_XS: return 5;
|
||||
case GGML_TYPE_IQ2_XXS: return 5;
|
||||
case GGML_TYPE_IQ3_S: return 4;
|
||||
case GGML_TYPE_IQ3_XXS: return 4;
|
||||
case GGML_TYPE_IQ4_NL: return 6;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 4;
|
||||
case GGML_TYPE_Q2_K: return 4;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 6;
|
||||
case GGML_TYPE_Q4_1: return 6;
|
||||
case GGML_TYPE_Q4_K: return 5;
|
||||
case GGML_TYPE_Q5_0: return 6;
|
||||
case GGML_TYPE_Q5_1: return 6;
|
||||
case GGML_TYPE_Q5_K: return 5;
|
||||
case GGML_TYPE_Q6_K: return 4;
|
||||
case GGML_TYPE_Q8_0: return 4;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_S: return 7;
|
||||
case GGML_TYPE_IQ3_S: return 6;
|
||||
case GGML_TYPE_IQ3_XXS: return 7;
|
||||
case GGML_TYPE_MXFP4: return 7;
|
||||
case GGML_TYPE_Q2_K: return 7;
|
||||
case GGML_TYPE_Q3_K: return 5;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ1_S: return 5;
|
||||
case GGML_TYPE_IQ1_M: return 5;
|
||||
case GGML_TYPE_IQ2_S: return 4;
|
||||
case GGML_TYPE_IQ2_XS: return 4;
|
||||
case GGML_TYPE_IQ2_XXS: return 4;
|
||||
case GGML_TYPE_IQ3_S: return 4;
|
||||
case GGML_TYPE_IQ3_XXS: return 4;
|
||||
case GGML_TYPE_IQ4_NL: return 6;
|
||||
case GGML_TYPE_IQ4_XS: return 4;
|
||||
case GGML_TYPE_Q2_K: return 4;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 5;
|
||||
case GGML_TYPE_Q4_1: return 5;
|
||||
case GGML_TYPE_Q4_K: return 4;
|
||||
case GGML_TYPE_Q5_K: return 4;
|
||||
case GGML_TYPE_Q6_K: return 4;
|
||||
case GGML_TYPE_Q8_0: return 4;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_S: return 5;
|
||||
case GGML_TYPE_IQ2_XS: return 5;
|
||||
case GGML_TYPE_IQ2_XXS: return 5;
|
||||
case GGML_TYPE_IQ3_S: return 4;
|
||||
case GGML_TYPE_IQ3_XXS: return 5;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_S: return 4;
|
||||
case GGML_TYPE_IQ2_XS: return 4;
|
||||
case GGML_TYPE_IQ2_XXS: return 4;
|
||||
case GGML_TYPE_IQ3_S: return 4;
|
||||
case GGML_TYPE_IQ3_XXS: return 4;
|
||||
case GGML_TYPE_Q2_K: return 7;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_K: return 5;
|
||||
case GGML_TYPE_Q5_K: return 6;
|
||||
case GGML_TYPE_Q6_K: return 5;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ1_S: return 6;
|
||||
case GGML_TYPE_IQ1_M: return 6;
|
||||
case GGML_TYPE_IQ2_S: return 4;
|
||||
case GGML_TYPE_IQ2_XS: return 4;
|
||||
case GGML_TYPE_IQ2_XXS: return 4;
|
||||
case GGML_TYPE_IQ3_S: return 4;
|
||||
case GGML_TYPE_IQ3_XXS: return 4;
|
||||
case GGML_TYPE_IQ4_NL: return 6;
|
||||
case GGML_TYPE_IQ4_XS: return 6;
|
||||
case GGML_TYPE_Q4_K: return 4;
|
||||
case GGML_TYPE_Q5_K: return 4;
|
||||
case GGML_TYPE_Q6_K: return 4;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ1_S: return 7;
|
||||
case GGML_TYPE_IQ1_M: return 7;
|
||||
case GGML_TYPE_IQ2_S: return 4;
|
||||
case GGML_TYPE_IQ2_XS: return 4;
|
||||
case GGML_TYPE_IQ2_XXS: return 4;
|
||||
case GGML_TYPE_IQ3_S: return 4;
|
||||
case GGML_TYPE_IQ3_XXS: return 4;
|
||||
case GGML_TYPE_IQ4_NL: return 7;
|
||||
case GGML_TYPE_IQ4_XS: return 5;
|
||||
case GGML_TYPE_MXFP4: return 5;
|
||||
case GGML_TYPE_Q3_K: return 4;
|
||||
case GGML_TYPE_Q4_0: return 7;
|
||||
case GGML_TYPE_Q4_1: return 7;
|
||||
case GGML_TYPE_Q4_K: return 4;
|
||||
case GGML_TYPE_Q5_0: return 7;
|
||||
case GGML_TYPE_Q5_1: return 7;
|
||||
case GGML_TYPE_Q5_K: return 5;
|
||||
case GGML_TYPE_Q6_K: return 5;
|
||||
case GGML_TYPE_Q8_0: return 7;
|
||||
default: return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
// Host function: returns the max batch size for the current arch+type at runtime.
|
||||
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
|
||||
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
|
||||
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
if (cc >= GGML_CUDA_CC_TURING) {
|
||||
return get_mmvq_mmid_max_batch_turing_plus(type);
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
return get_mmvq_mmid_max_batch_pascal_older(type);
|
||||
}
|
||||
// AMD
|
||||
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
return get_mmvq_mmid_max_batch_rdna4(type);
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
|
||||
return get_mmvq_mmid_max_batch_rdna3(type);
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
|
||||
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return get_mmvq_mmid_max_batch_cdna(type);
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_GCN(cc)) {
|
||||
return get_mmvq_mmid_max_batch_gcn(type);
|
||||
}
|
||||
return MMVQ_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
// Device constexpr: returns the max batch size for the current arch+type at compile time.
|
||||
template <ggml_type type>
|
||||
static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
|
||||
#if defined(RDNA4)
|
||||
return get_mmvq_mmid_max_batch_rdna4(type);
|
||||
#elif defined(RDNA3)
|
||||
return get_mmvq_mmid_max_batch_rdna3(type);
|
||||
#elif defined(RDNA2) || defined(RDNA1)
|
||||
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
|
||||
#elif defined(CDNA)
|
||||
return get_mmvq_mmid_max_batch_cdna(type);
|
||||
#elif defined(GCN)
|
||||
return get_mmvq_mmid_max_batch_gcn(type);
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE)
|
||||
return MMVQ_MAX_BATCH_SIZE;
|
||||
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
return get_mmvq_mmid_max_batch_turing_plus(type);
|
||||
#else
|
||||
return get_mmvq_mmid_max_batch_pascal_older(type);
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
||||
switch (ncols_dst) {
|
||||
@@ -195,7 +383,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
|
||||
template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
|
||||
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||
@@ -222,22 +410,13 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
const uint32_t channel_dst = blockIdx.y;
|
||||
|
||||
uint32_t token_idx = 0;
|
||||
uint32_t channel_x;
|
||||
uint32_t channel_y;
|
||||
uint32_t sample_dst;
|
||||
|
||||
if constexpr (is_multi_token_id) {
|
||||
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
|
||||
token_idx = blockIdx.z;
|
||||
channel_x = ids[channel_dst + token_idx * ids_stride];
|
||||
channel_y = fastmodulo(channel_dst, nchannels_y);
|
||||
sample_dst = 0;
|
||||
} else {
|
||||
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
|
||||
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
|
||||
sample_dst = blockIdx.z;
|
||||
}
|
||||
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
|
||||
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
|
||||
sample_dst = blockIdx.z;
|
||||
|
||||
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
||||
const uint32_t sample_y = sample_dst;
|
||||
@@ -294,9 +473,6 @@ static __global__ void mul_mat_vec_q(
|
||||
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||
|
||||
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
|
||||
if constexpr (is_multi_token_id) {
|
||||
y += token_idx*stride_col_y;
|
||||
}
|
||||
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
|
||||
|
||||
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
||||
@@ -350,10 +526,6 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
|
||||
|
||||
if constexpr (is_multi_token_id) {
|
||||
dst += token_idx*stride_col_dst;
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
@@ -413,6 +585,69 @@ static __global__ void mul_mat_vec_q(
|
||||
}
|
||||
}
|
||||
|
||||
// Dedicated MoE multi-token kernel.
|
||||
// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst)
|
||||
// Block: (warp_size, ncols_dst) - each warp handles one token independently.
|
||||
// No shared memory reduction needed since each warp works alone.
|
||||
template <ggml_type type, int c_rows_per_block>
|
||||
__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mul_mat_vec_q_moe(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids,
|
||||
float * __restrict__ dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
|
||||
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
|
||||
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
|
||||
const uint32_t ncols_dst, const uint32_t ids_stride) {
|
||||
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
constexpr int vdr = get_vdr_mmvq(type);
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
const uint32_t token_idx = threadIdx.y;
|
||||
const int row0 = c_rows_per_block*blockIdx.x;
|
||||
const int blocks_per_row_x = ncols_x / qk;
|
||||
constexpr int blocks_per_iter = vdr * warp_size / qi;
|
||||
|
||||
const uint32_t channel_dst = blockIdx.y;
|
||||
|
||||
if (token_idx >= ncols_dst) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
|
||||
const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);
|
||||
|
||||
const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y;
|
||||
const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp[c_rows_per_block] = {0.0f};
|
||||
|
||||
for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
||||
const int kby = kbx * (qk/QK8_1);
|
||||
const int kqs = vdr * (threadIdx.x % (qi/vdr));
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < c_rows_per_block; ++i) {
|
||||
tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||
}
|
||||
}
|
||||
|
||||
// Warp-level reduction only - no shared memory needed
|
||||
#pragma unroll
|
||||
for (int i = 0; i < c_rows_per_block; ++i) {
|
||||
tmp[i] = warp_reduce_sum<warp_size>(tmp[i]);
|
||||
}
|
||||
|
||||
// Write results
|
||||
if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) {
|
||||
dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x];
|
||||
}
|
||||
}
|
||||
|
||||
template<ggml_type type>
|
||||
static std::pair<dim3, dim3> calc_launch_params(
|
||||
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
|
||||
@@ -425,7 +660,7 @@ static std::pair<dim3, dim3> calc_launch_params(
|
||||
return {block_nums, block_dims};
|
||||
}
|
||||
|
||||
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
|
||||
template<ggml_type type, int c_ncols_dst, bool small_k = false>
|
||||
static void mul_mat_vec_q_switch_fusion(
|
||||
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||
@@ -438,7 +673,7 @@ static void mul_mat_vec_q_switch_fusion(
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (c_ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
mul_mat_vec_q<type, c_ncols_dst, true, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
@@ -448,12 +683,33 @@ static void mul_mat_vec_q_switch_fusion(
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
mul_mat_vec_q<type, c_ncols_dst, false, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_moe_launch(
|
||||
const void * vx, const void * vy, const int32_t * ids, float * dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
|
||||
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
|
||||
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
|
||||
const uint32_t ncols_dst, const uint32_t ids_stride,
|
||||
const int warp_size, const int nchannels_dst, cudaStream_t stream) {
|
||||
|
||||
constexpr int rows_per_block = 2; // 2 gives best perf based on tuning
|
||||
const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
|
||||
const dim3 block_nums(nblocks_rows, nchannels_dst);
|
||||
const dim3 block_dims(warp_size, ncols_dst);
|
||||
|
||||
mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>(
|
||||
vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
|
||||
stride_row_x, stride_col_y, stride_col_dst,
|
||||
stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
ncols_dst, ids_stride);
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_switch_ncols_dst(
|
||||
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
@@ -472,20 +728,62 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
||||
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
|
||||
|
||||
const int device = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
|
||||
const mmvq_parameter_table_id table_id = get_device_table_id(cc);
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
const bool has_ids = ids != nullptr;
|
||||
|
||||
const auto should_use_small_k = [&](int c_ncols_dst) {
|
||||
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
|
||||
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
constexpr int vdr = get_vdr_mmvq(type);
|
||||
const int blocks_per_row_x = ncols_x / qk;
|
||||
const int blocks_per_iter_1warp = vdr * warp_size / qi;
|
||||
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
|
||||
bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
|
||||
|
||||
constexpr std::array<ggml_type, 2> iq_slow_turing = {
|
||||
GGML_TYPE_IQ3_XXS,
|
||||
GGML_TYPE_IQ3_S,
|
||||
};
|
||||
constexpr std::array<ggml_type, 8> iq_slow_other = {
|
||||
GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
|
||||
GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
|
||||
};
|
||||
constexpr std::array<ggml_type, 3> slow_pascal = {
|
||||
GGML_TYPE_IQ3_S,
|
||||
GGML_TYPE_Q2_K,
|
||||
GGML_TYPE_Q3_K,
|
||||
};
|
||||
|
||||
const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING;
|
||||
const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA;
|
||||
|
||||
if (is_nvidia_turing_plus) {
|
||||
if (ncols_dst == 1 &&
|
||||
std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) {
|
||||
use = false;
|
||||
}
|
||||
} else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) ||
|
||||
(is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) ||
|
||||
GGML_CUDA_CC_IS_RDNA(cc)) {
|
||||
use = false;
|
||||
}
|
||||
|
||||
return use;
|
||||
};
|
||||
|
||||
if (has_ids && ncols_dst > 1) {
|
||||
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
|
||||
constexpr int c_ncols_dst = 1;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, ids_stride, stream);
|
||||
// Multi-token MUL_MAT_ID path - dedicated MoE kernel
|
||||
mul_mat_vec_q_moe_launch<type>(
|
||||
vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x,
|
||||
stride_row_x, stride_col_y, stride_col_dst,
|
||||
stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
ncols_dst, ids_stride, warp_size, nchannels_dst, stream);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -493,31 +791,24 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
||||
case 1: {
|
||||
constexpr int c_ncols_dst = 1;
|
||||
|
||||
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
|
||||
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
constexpr int vdr = get_vdr_mmvq(type);
|
||||
const int blocks_per_row_x = ncols_x / qk;
|
||||
const int blocks_per_iter_1warp = vdr * warp_size / qi;
|
||||
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
|
||||
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
|
||||
bool use_small_k = should_use_small_k(c_ncols_dst);
|
||||
|
||||
if (use_small_k) {
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
|
||||
warp_size, table_id, true);
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
|
||||
nsamples_dst, warp_size, table_id, true);
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
|
||||
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, ids_stride, stream);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
|
||||
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
|
||||
stream);
|
||||
} else {
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
|
||||
warp_size, table_id);
|
||||
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
|
||||
nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
|
||||
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, ids_stride, stream);
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
|
||||
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
|
||||
stream);
|
||||
}
|
||||
} break;
|
||||
case 2: {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
|
||||
|
||||
// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID,
|
||||
// based on the quantization type and GPU architecture (compute capability).
|
||||
int get_mmvq_mmid_max_batch(ggml_type type, int cc);
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
@@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
|
||||
|
||||
dma_cache m_cache;
|
||||
dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE);
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
||||
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
|
||||
@@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
|
||||
// Mask is 1D contiguous for this row
|
||||
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
||||
dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
||||
}
|
||||
|
||||
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
|
||||
@@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
|
||||
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
|
||||
dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
|
||||
}
|
||||
|
||||
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
|
||||
@@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
octx->src0_spad.size_per_thread = size_q_block * 1;
|
||||
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
|
||||
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0;
|
||||
octx->dst_spad.size_per_thread = size_vkq_acc;
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
||||
@@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
||||
|
||||
// FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
|
||||
desc->desc_size = 0; // 1D mode
|
||||
desc->src_bypass = dma_src_l2_bypass_on;
|
||||
desc->dst_bypass = dma_dst_l2_bypass_on;
|
||||
desc->order = 1;
|
||||
desc->order = 0;
|
||||
desc->done = 0;
|
||||
desc->src = (void *) dptr.src;
|
||||
desc->dst = (void *) dptr.dst;
|
||||
@@ -151,8 +151,12 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
|
||||
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = (dma_descriptor_2d *) desc;
|
||||
if (size) {
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = (dma_descriptor_2d *) desc;
|
||||
} else {
|
||||
desc->done = 1;
|
||||
}
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
@@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
|
||||
desc->dst_bypass = dma_dst_l2_bypass_on;
|
||||
desc->src_comp = 0;
|
||||
desc->dst_comp = 0;
|
||||
desc->order = 1;
|
||||
desc->order = 0;
|
||||
desc->done = 0;
|
||||
desc->src_stride = src_stride;
|
||||
desc->dst_stride = dst_stride;
|
||||
@@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
|
||||
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
if (nrows) {
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
} else {
|
||||
desc->done = 1;
|
||||
}
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
@@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
|
||||
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
|
||||
|
||||
// Wait for desc to complete
|
||||
while (1) {
|
||||
dmpoll();
|
||||
if (desc->done) {
|
||||
break;
|
||||
}
|
||||
while (!desc->done) {
|
||||
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
|
||||
dmpoll();
|
||||
}
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
@@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
|
||||
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
|
||||
}
|
||||
|
||||
#define DMA_CACHE_MAX_SIZE 64U
|
||||
|
||||
typedef struct {
|
||||
uint8_t *base;
|
||||
uint32_t line_size;
|
||||
uint32_t capacity;
|
||||
uint32_t src[DMA_CACHE_MAX_SIZE];
|
||||
uint16_t age[DMA_CACHE_MAX_SIZE];
|
||||
} dma_cache;
|
||||
|
||||
static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity)
|
||||
{
|
||||
c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity;
|
||||
c->base = base;
|
||||
c->line_size = line_size;
|
||||
|
||||
for (unsigned i=0; i < c->capacity; i++) {
|
||||
c->src[i] = 0;
|
||||
c->age[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows)
|
||||
{
|
||||
uint32_t o_idx = 0;
|
||||
uint16_t o_age = 0;
|
||||
uint8_t * dst = 0;
|
||||
|
||||
for (unsigned i=0; i < c->capacity; i++) {
|
||||
if (c->src[i] == (uint32_t) src) {
|
||||
c->age[i] = 0;
|
||||
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
|
||||
// FARF(ERROR, "dma-cache: found %p", src);
|
||||
} else {
|
||||
c->age[i]++;
|
||||
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
|
||||
}
|
||||
}
|
||||
if (!dst) {
|
||||
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
|
||||
c->age[o_idx] = 0;
|
||||
c->src[o_idx] = (uint32_t) src;
|
||||
dst = c->base + o_idx * c->line_size; // normal nrows dma
|
||||
}
|
||||
|
||||
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
@@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
}
|
||||
|
||||
// Skip DMA transactions from prev block (if any)
|
||||
// No need to wait for these since the DMA is setup for in-order processing
|
||||
// Skip output DMA transactions from prev block (if any)
|
||||
// No need to wait for those here since we're explicitly waiting for the latest prefecthes below.
|
||||
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
|
||||
|
||||
// Compute loop
|
||||
|
||||
@@ -1340,7 +1340,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
|
||||
if (buffer && buffer->iface.init_tensor) {
|
||||
buffer->iface.init_tensor(buffer, tensor);
|
||||
} else {
|
||||
GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
|
||||
if (!buffer) {
|
||||
GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (tensor->extra != nullptr) {
|
||||
|
||||
@@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants {
|
||||
uint32_t mode; // 0: default, 1: swapped, 2: split
|
||||
float alpha; // for swiglu_oai
|
||||
float limit;
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t ne01;
|
||||
uint32_t ne02;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t ne11;
|
||||
uint32_t ne12;
|
||||
};
|
||||
|
||||
struct vk_op_unary_push_constants {
|
||||
@@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
} else {
|
||||
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
|
||||
}
|
||||
vk::DeviceCreateInfo device_create_info;
|
||||
vk::DeviceCreateInfo device_create_info{};
|
||||
std::vector<const char *> device_extensions;
|
||||
vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
|
||||
|
||||
@@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
#endif
|
||||
device->name = GGML_VK_NAME + std::to_string(idx);
|
||||
|
||||
device_create_info = {
|
||||
vk::DeviceCreateFlags(),
|
||||
device_queue_create_infos,
|
||||
{},
|
||||
device_extensions
|
||||
};
|
||||
device_create_info
|
||||
.setFlags(vk::DeviceCreateFlags())
|
||||
.setQueueCreateInfos(device_queue_create_infos)
|
||||
.setPEnabledExtensionNames(device_extensions);
|
||||
device_create_info.setPNext(&device_features2);
|
||||
device->device = device->physical_device.createDevice(device_create_info);
|
||||
|
||||
@@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
const float alpha = op_params_f[2];
|
||||
const float limit = op_params_f[3];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
if (!split) {
|
||||
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
|
||||
} else {
|
||||
@@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
(uint32_t)dst->ne[0],
|
||||
mode,
|
||||
alpha,
|
||||
limit
|
||||
limit,
|
||||
(uint32_t)(src0->nb[1] / src0->nb[0]),
|
||||
(uint32_t)(src0->nb[2] / src0->nb[0]),
|
||||
(uint32_t)(src0->nb[3] / src0->nb[0]),
|
||||
(uint32_t)src0->ne[1],
|
||||
(uint32_t)src0->ne[2],
|
||||
(uint32_t)(dst->nb[1] / dst->nb[0]),
|
||||
(uint32_t)(dst->nb[2] / dst->nb[0]),
|
||||
(uint32_t)(dst->nb[3] / dst->nb[0]),
|
||||
(uint32_t)dst->ne[1],
|
||||
(uint32_t)dst->ne[2]
|
||||
});
|
||||
}
|
||||
|
||||
@@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_GLU_OP_SWIGLU_OAI:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||
(op->src[0]->type == op->type);
|
||||
default:
|
||||
|
||||
@@ -16,4 +16,14 @@ layout (push_constant) uniform parameter
|
||||
uint mode;
|
||||
float alpha;
|
||||
float limit;
|
||||
uint nb01;
|
||||
uint nb02;
|
||||
uint nb03;
|
||||
uint ne01;
|
||||
uint ne02;
|
||||
uint nb11;
|
||||
uint nb12;
|
||||
uint nb13;
|
||||
uint ne11;
|
||||
uint ne12;
|
||||
} p;
|
||||
|
||||
@@ -8,22 +8,32 @@ void main() {
|
||||
const uint row = i / p.ne20;
|
||||
const uint col = i - row * p.ne20;
|
||||
|
||||
const uint i3 = row / (p.ne01 * p.ne02);
|
||||
const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01;
|
||||
const uint i1 = row % p.ne01;
|
||||
const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col;
|
||||
|
||||
const uint dst_i3 = row / (p.ne11 * p.ne12);
|
||||
const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11;
|
||||
const uint dst_i1 = row % p.ne11;
|
||||
const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col;
|
||||
|
||||
if (p.mode == 0) {
|
||||
// Default
|
||||
const uint offset = p.ne00 / 2;
|
||||
const uint idx = row * p.ne00 + col;
|
||||
const uint idx = src_idx;
|
||||
|
||||
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
|
||||
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
|
||||
} else if (p.mode == 1) {
|
||||
// Swapped
|
||||
const uint offset = p.ne00 / 2;
|
||||
const uint idx = row * p.ne00 + col;
|
||||
const uint idx = src_idx;
|
||||
|
||||
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
|
||||
data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
|
||||
} else {
|
||||
// Split
|
||||
const uint idx = row * p.ne00 + col;
|
||||
const uint idx = src_idx;
|
||||
|
||||
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
|
||||
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,12 +14,12 @@ except ImportError:
|
||||
SentencePieceProcessor: Any = None
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # type: ignore[import-not-found, ty:unresolved-import]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found, ty:unresolved-import]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found, ty:unresolved-import]
|
||||
_filter_valid_tokenizer_files,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found, ty:unresolved-import]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -32,7 +32,7 @@ else:
|
||||
_mistral_common_installed = True
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found, ty:unresolved-import]
|
||||
get_one_valid_tokenizer_file,
|
||||
)
|
||||
except ImportError:
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
{%- set image_count = namespace(value=0) %}
|
||||
{%- set video_count = namespace(value=0) %}
|
||||
{%- macro render_content(content, do_vision_count, is_system_content=false) %}
|
||||
{%- if content is string %}
|
||||
{{- content }}
|
||||
{%- elif content is iterable and content is not mapping %}
|
||||
{%- for item in content %}
|
||||
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
|
||||
{%- if is_system_content %}
|
||||
{{- raise_exception('System message cannot contain images.') }}
|
||||
{%- endif %}
|
||||
{%- if do_vision_count %}
|
||||
{%- set image_count.value = image_count.value + 1 %}
|
||||
{%- endif %}
|
||||
{%- if add_vision_id %}
|
||||
{{- 'Picture ' ~ image_count.value ~ ': ' }}
|
||||
{%- endif %}
|
||||
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
|
||||
{%- elif 'video' in item or item.type == 'video' %}
|
||||
{%- if is_system_content %}
|
||||
{{- raise_exception('System message cannot contain videos.') }}
|
||||
{%- endif %}
|
||||
{%- if do_vision_count %}
|
||||
{%- set video_count.value = video_count.value + 1 %}
|
||||
{%- endif %}
|
||||
{%- if add_vision_id %}
|
||||
{{- 'Video ' ~ video_count.value ~ ': ' }}
|
||||
{%- endif %}
|
||||
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
|
||||
{%- elif 'text' in item %}
|
||||
{{- item.text }}
|
||||
{%- else %}
|
||||
{{- raise_exception('Unexpected item type in content.') }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- elif content is none or content is undefined %}
|
||||
{{- '' }}
|
||||
{%- else %}
|
||||
{{- raise_exception('Unexpected content type.') }}
|
||||
{%- endif %}
|
||||
{%- endmacro %}
|
||||
{%- if not messages %}
|
||||
{{- raise_exception('No messages provided.') }}
|
||||
{%- endif %}
|
||||
{%- if tools and tools is iterable and tools is not mapping %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>" }}
|
||||
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{%- set content = render_content(messages[0].content, false, true)|trim %}
|
||||
{%- if content %}
|
||||
{{- '\n\n' + content }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{%- set content = render_content(messages[0].content, false, true)|trim %}
|
||||
{{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" %}
|
||||
{%- set content = render_content(message.content, false)|trim %}
|
||||
{%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if ns.multi_step_tool %}
|
||||
{{- raise_exception('No user query found in messages.') }}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- set content = render_content(message.content, true)|trim %}
|
||||
{%- if message.role == "system" %}
|
||||
{%- if not loop.first %}
|
||||
{{- raise_exception('System message must be at the beginning.') }}
|
||||
{%- endif %}
|
||||
{%- elif message.role == "user" %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- if message.reasoning_content is string %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set reasoning_content = reasoning_content|trim %}
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{%- if loop.first %}
|
||||
{%- if content|trim %}
|
||||
{{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||
{%- else %}
|
||||
{{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{%- for args_name, args_value in tool_call.arguments|items %}
|
||||
{{- '<parameter=' + args_name + '>\n' }}
|
||||
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
||||
{{- args_value }}
|
||||
{{- '\n</parameter>\n' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif loop.last %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- raise_exception('Unexpected message role.') }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||
{{- '<think>\n\n</think>\n\n' }}
|
||||
{%- else %}
|
||||
{{- '<think>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
@@ -147,7 +147,7 @@ ranges_nfd: list[tuple[int, int, int]] = [(0, 0, 0)] # start, last, nfd
|
||||
for codepoint, norm in table_nfd:
|
||||
start = ranges_nfd[-1][0]
|
||||
if ranges_nfd[-1] != (start, codepoint - 1, norm):
|
||||
ranges_nfd.append(None) # type: ignore[arg-type] # dummy, will be replaced below
|
||||
ranges_nfd.append((0, 0, 0)) # dummy, will be replaced below
|
||||
start = codepoint
|
||||
ranges_nfd[-1] = (start, codepoint, norm)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.39.0"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.40.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
@@ -557,6 +557,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_ROPE_FREQS,
|
||||
LLM_TENSOR_ROPE_FACTORS_LONG,
|
||||
LLM_TENSOR_ROPE_FACTORS_SHORT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_K,
|
||||
|
||||
@@ -1158,6 +1158,12 @@ struct ggml_tensor * llama_model_loader::create_tensor(
|
||||
if (overrides->buft == ggml_backend_cpu_buffer_type()) {
|
||||
// when overriding to a CPU buffer, consider the extra buffer types
|
||||
buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu);
|
||||
if (use_mmap) {
|
||||
static std::once_flag once;
|
||||
std::call_once(once, [] {
|
||||
LLAMA_LOG_WARN("llama_model_loader: tensor overrides to CPU are used with mmap enabled - consider using --no-mmap for better performance\n");
|
||||
});
|
||||
}
|
||||
} else {
|
||||
buft = overrides->buft;
|
||||
}
|
||||
|
||||
@@ -8424,6 +8424,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 256, 1, 1}, order)); // test ceildiv in CUDA's CUB's DeviceSegmentedSort
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
|
||||
@@ -1330,7 +1330,7 @@ static void test_nemotron_reasoning_detection(testing & t) {
|
||||
analysis.analyze_template(tmpl);
|
||||
|
||||
// Check reasoning markers
|
||||
t.assert_equal("reasoning_start should be '<think>'", "<think>", analysis.reasoning.start);
|
||||
t.assert_equal("reasoning_start should be '<think>\\n'", "<think>\n", analysis.reasoning.start);
|
||||
t.assert_equal("reasoning_end should be '</think>'", "</think>", analysis.reasoning.end);
|
||||
|
||||
// Check reasoning mode detection
|
||||
|
||||
+524
-81
@@ -425,6 +425,7 @@ static common_chat_tool special_function_tool_with_optional_param{
|
||||
"required": ["arg1"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool empty_args_tool{
|
||||
/* .name = */ "empty_args",
|
||||
/* .description = */ "A tool that takes no arguments",
|
||||
@@ -433,6 +434,15 @@ static common_chat_tool empty_args_tool{
|
||||
"properties": {}
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool empty_args_tool_no_properties{
|
||||
/* .name = */ "empty_args_no_props",
|
||||
/* .description = */ "A tool that takes no arguments and has no properties",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object"
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool python_tool{
|
||||
/* .name = */ "python",
|
||||
/* .description = */ "an ipython interpreter",
|
||||
@@ -805,7 +815,8 @@ struct peg_test_case {
|
||||
common_chat_templates_inputs params;
|
||||
std::string input;
|
||||
common_chat_msg expect;
|
||||
bool is_partial = false;
|
||||
bool is_partial = false;
|
||||
bool expect_reconstruction = false;
|
||||
};
|
||||
|
||||
struct make_peg_parser {
|
||||
@@ -828,6 +839,12 @@ struct make_peg_parser {
|
||||
}
|
||||
};
|
||||
|
||||
// Global template filter for --template flag
|
||||
static std::string g_template_filter;
|
||||
|
||||
// When true, run reconstruction test on every non-partial test and report results
|
||||
static bool g_force_reconstruction_test = false;
|
||||
|
||||
static void test_peg_parser(common_chat_templates * tmpls,
|
||||
const std::function<void(peg_test_case &)> & init,
|
||||
bool detailed_debug) {
|
||||
@@ -936,75 +953,158 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
||||
throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar);
|
||||
}
|
||||
|
||||
// Find the earliest trigger position to determine the constrained portion
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
size_t pos = std::string::npos;
|
||||
std::smatch match;
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = tc.input.find(word);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & pattern = std::regex(trigger.value);
|
||||
if (std::regex_search(tc.input, match, pattern)) {
|
||||
pos = match.position(pattern.mark_count());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
if (std::regex_match(tc.input, match, std::regex(pattern))) {
|
||||
auto mpos = std::string::npos;
|
||||
for (size_t i = 1; i < match.size(); ++i) {
|
||||
if (match[i].length() > 0) {
|
||||
mpos = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mpos == std::string::npos) {
|
||||
mpos = match.position(0);
|
||||
}
|
||||
pos = mpos;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown trigger type");
|
||||
}
|
||||
if (pos != std::string::npos) {
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = pos;
|
||||
// In production, grammar triggers match against the full generated text
|
||||
// including the generation prompt. All positions are in full_input coordinates.
|
||||
const auto & gen_prompt = parser.params_.generation_prompt;
|
||||
std::string full_input = gen_prompt + tc.input;
|
||||
|
||||
// Determine whether the reasoning-budget sampler path applies: tool-call grammar
|
||||
// with all WORD triggers and thinking tags present. In production, the reasoning
|
||||
// budget sampler inhibits grammar application while inside thinking blocks —
|
||||
// triggers inside <think>...</think> are suppressed.
|
||||
bool use_reasoning_budget_path = false;
|
||||
if (parser.params_.grammar_lazy && !parser.params_.thinking_end_tag.empty()) {
|
||||
use_reasoning_budget_path = true;
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
if (trigger.type != COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||
use_reasoning_budget_path = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the constrained portion of input to test against grammar
|
||||
std::string constrained = tc.input;
|
||||
// Find the earliest trigger position to determine the constrained portion
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
|
||||
if (use_reasoning_budget_path) {
|
||||
// Reasoning-budget path: simulate thinking-aware trigger detection.
|
||||
// Walk through full_input tracking thinking state; only match triggers
|
||||
// when outside thinking blocks.
|
||||
const auto & think_start = parser.params_.thinking_start_tag;
|
||||
const auto & think_end = parser.params_.thinking_end_tag;
|
||||
|
||||
bool in_thinking = false;
|
||||
for (size_t i = 0; i < full_input.size(); ++i) {
|
||||
if (!in_thinking && !think_start.empty()
|
||||
&& full_input.compare(i, think_start.size(), think_start) == 0) {
|
||||
in_thinking = true;
|
||||
i += think_start.size() - 1;
|
||||
continue;
|
||||
}
|
||||
if (in_thinking && full_input.compare(i, think_end.size(), think_end) == 0) {
|
||||
in_thinking = false;
|
||||
i += think_end.size() - 1;
|
||||
continue;
|
||||
}
|
||||
if (in_thinking) {
|
||||
continue;
|
||||
}
|
||||
// Outside thinking — check if any trigger word starts here
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
if (full_input.compare(i, trigger.value.size(), trigger.value) == 0) {
|
||||
if (earliest_trigger_pos == std::string::npos || i < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (earliest_trigger_pos != std::string::npos) {
|
||||
break; // found the earliest
|
||||
}
|
||||
}
|
||||
|
||||
// If the reasoning-budget path found no trigger outside thinking but the test
|
||||
// expects tool calls, this template nests tool calls inside thinking
|
||||
// blocks (e.g. Kimi). Fall back to the legacy path for this case.
|
||||
if (earliest_trigger_pos == std::string::npos && !tc.expect.tool_calls.empty()) {
|
||||
use_reasoning_budget_path = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!use_reasoning_budget_path) {
|
||||
// Legacy path: find triggers without thinking-awareness
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
size_t pos = std::string::npos;
|
||||
std::smatch match;
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = full_input.find(word);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & compiled = std::regex(trigger.value);
|
||||
if (std::regex_search(full_input, match, compiled)) {
|
||||
pos = match.position(compiled.mark_count());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
// In production, PATTERN_FULL triggers are checked against
|
||||
// the text generated so far, growing token by token. Simulate
|
||||
// by trying every prefix of full_input.
|
||||
const auto & compiled = std::regex(trigger.value);
|
||||
for (size_t end = gen_prompt.size(); end <= full_input.size(); ++end) {
|
||||
std::string prefix = full_input.substr(0, end);
|
||||
if (std::regex_match(prefix, match, compiled)) {
|
||||
pos = std::string::npos;
|
||||
for (size_t gi = 1; gi < match.size(); ++gi) {
|
||||
if (match[gi].length() > 0) {
|
||||
pos = match.position(gi);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos == std::string::npos) {
|
||||
pos = match.position(0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown trigger type");
|
||||
}
|
||||
if (pos != std::string::npos) {
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the test expects tool calls and the grammar is lazy, the trigger must fire.
|
||||
// Otherwise the grammar would never activate in production and tool calls wouldn't
|
||||
// be constrained. A silent skip here would hide broken triggers.
|
||||
if (parser.params_.grammar_lazy && !tc.expect.tool_calls.empty() && !tc.is_partial
|
||||
&& earliest_trigger_pos == std::string::npos) {
|
||||
std::string trigger_desc;
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
trigger_desc += "\n [type=" + std::to_string(trigger.type) + "] " + trigger.value;
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"Grammar trigger did not fire, but test expects tool calls (lazy grammar).\n"
|
||||
">>> Input: " + full_input + "\n"
|
||||
">>> Triggers (" + std::to_string(parser.params_.grammar_triggers.size()) + "):" + trigger_desc);
|
||||
}
|
||||
|
||||
// Determine the constrained portion of input to test against grammar.
|
||||
// If the trigger position falls inside the generation prompt, the grammar
|
||||
// sampler was already active before model output began — constrain from the
|
||||
// start of the model output (i.e. tc.input).
|
||||
std::string constrained = full_input;
|
||||
bool grammar_triggered = false;
|
||||
if (earliest_trigger_pos != std::string::npos) {
|
||||
constrained = tc.input.substr(earliest_trigger_pos);
|
||||
auto constrain_from = std::max(earliest_trigger_pos, gen_prompt.size());
|
||||
constrained = full_input.substr(constrain_from);
|
||||
grammar_triggered = true;
|
||||
} else if (!parser.params_.grammar_lazy) {
|
||||
// For non-lazy grammars, the entire input should match
|
||||
grammar_triggered = true;
|
||||
}
|
||||
|
||||
// For non-lazy grammars, prepend reasoning prefill to grammar input, just like
|
||||
// PEG parsing does. The grammar includes the full reasoning pattern (e.g. optional
|
||||
// <think>...</think>), but the model output may start mid-reasoning if the template
|
||||
// already placed the opening tag in the prompt.
|
||||
// For lazy grammars, the grammar only activates from the trigger position, so the
|
||||
// reasoning prefill is irrelevant — reasoning is handled by the PEG parser.
|
||||
if (!parser.params_.generation_prompt.empty() && earliest_trigger_pos == std::string::npos) {
|
||||
constrained = parser.params_.generation_prompt + constrained;
|
||||
}
|
||||
|
||||
// Test the constrained portion against the grammar
|
||||
if (grammar_triggered && !tc.is_partial) {
|
||||
auto result = match_string_detailed(constrained, grammar.get());
|
||||
@@ -1036,10 +1136,57 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Global template filter for --template flag
|
||||
static std::string g_template_filter;
|
||||
// Reconstruction test: verify that appending the parsed message to the original
|
||||
// messages and re-rendering the template (without generation prompt) reproduces
|
||||
// the original prompt + input exactly, or as a proper prefix (the template may
|
||||
// append end-of-turn tokens after the assistant message).
|
||||
if ((tc.expect_reconstruction || g_force_reconstruction_test) && !tc.is_partial) {
|
||||
// Start from tc.expect but copy tool call arguments from the actual parser
|
||||
// output, which preserves original JSON formatting (e.g. {"arg1":1} vs {"arg1": 1}).
|
||||
auto reconstruction_msg = tc.expect;
|
||||
auto parsed_msg = parser.parse(tc.input, false);
|
||||
for (size_t i = 0; i < reconstruction_msg.tool_calls.size() && i < parsed_msg.tool_calls.size(); i++) {
|
||||
reconstruction_msg.tool_calls[i].arguments = parsed_msg.tool_calls[i].arguments;
|
||||
}
|
||||
common_chat_templates_inputs reconstruction_inputs = tc.params;
|
||||
reconstruction_inputs.messages.push_back(reconstruction_msg);
|
||||
reconstruction_inputs.add_generation_prompt = false;
|
||||
|
||||
auto reconstruction_params = common_chat_templates_apply(tmpls, reconstruction_inputs);
|
||||
std::string expected_text = parser.params_.prompt + tc.input;
|
||||
bool match = reconstruction_params.prompt == expected_text ||
|
||||
(reconstruction_params.prompt.size() > expected_text.size() &&
|
||||
reconstruction_params.prompt.compare(0, expected_text.size(), expected_text) == 0);
|
||||
if (!match && g_force_reconstruction_test && !tc.expect_reconstruction) {
|
||||
// In forced mode, report mismatch but don't fail
|
||||
// Find the first difference position
|
||||
size_t diff_pos = 0;
|
||||
size_t min_len = std::min(expected_text.size(), reconstruction_params.prompt.size());
|
||||
while (diff_pos < min_len && expected_text[diff_pos] == reconstruction_params.prompt[diff_pos]) {
|
||||
diff_pos++;
|
||||
}
|
||||
size_t ctx_start = diff_pos > 60 ? diff_pos - 60 : 0;
|
||||
size_t ctx_end_e = std::min(expected_text.size(), diff_pos + 40);
|
||||
size_t ctx_end_r = std::min(reconstruction_params.prompt.size(), diff_pos + 40);
|
||||
LOG_ERR("\x1b[31m[RECONSTRUCTION FAIL]\x1b[0m "
|
||||
"first diff at byte %zu (expected len=%zu, reconstructed len=%zu)\n"
|
||||
" expected: ...%s...\n"
|
||||
" reconstructed: ...%s...\n",
|
||||
diff_pos, expected_text.size(), reconstruction_params.prompt.size(),
|
||||
expected_text.substr(ctx_start, ctx_end_e - ctx_start).c_str(),
|
||||
reconstruction_params.prompt.substr(ctx_start, ctx_end_r - ctx_start).c_str());
|
||||
} else if (!match) {
|
||||
std::string error_msg =
|
||||
"Reconstruction mismatch:\n\n"
|
||||
">>> Expected (prompt + input):\n" + expected_text +
|
||||
"\n\n>>> Reconstructed:\n" + reconstruction_params.prompt;
|
||||
throw std::runtime_error(error_msg);
|
||||
} else if (g_force_reconstruction_test) {
|
||||
LOG_INF("\x1b[32m[RECONSTRUCTION OK]\x1b[0m\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fluent builder for PEG parser tests
|
||||
class peg_test_builder;
|
||||
@@ -1099,6 +1246,11 @@ class peg_test_builder {
|
||||
return *this;
|
||||
}
|
||||
|
||||
peg_test_builder & expect_reconstruction(bool val = true) {
|
||||
tc_.expect_reconstruction = val;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Expect setters
|
||||
peg_test_builder & expect(const common_chat_msg & msg) {
|
||||
tc_.expect = msg;
|
||||
@@ -1268,20 +1420,192 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
}
|
||||
})";
|
||||
|
||||
{
|
||||
// Qwen3.5 (basically same as Nemotron, but keeping separate tests just in case)
|
||||
auto tst = peg_tester("models/templates/Qwen3.5-4B.jinja", detailed_debug);
|
||||
|
||||
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
|
||||
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_NONE)
|
||||
.expect_content("<think>\nI'm\nthinking\n</think>\nHello, world!\nWhat's up?")
|
||||
.run();
|
||||
|
||||
tst.test("I'm\nthinking\n</think>\nHello, world!\nWhat's up?")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=special_function>\n"
|
||||
"<parameter=arg1>\n1\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"I'm\nthinking\n</think>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=special_function>\n"
|
||||
"<parameter=arg1>\n1\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_thoughts)
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=special_function>\n"
|
||||
"<parameter=arg1>\n1\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=special_function_with_opt>\n"
|
||||
"<parameter=arg1>\n1\n</parameter>\n"
|
||||
"<parameter=arg2>\n2\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.parallel_tool_calls(true)
|
||||
.tools({
|
||||
special_function_tool, special_function_tool_with_optional_param
|
||||
})
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({
|
||||
python_tool
|
||||
})
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"I need to output the invoice details in JSON\n"
|
||||
"</think>\n"
|
||||
R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.json_schema(invoice_schema)
|
||||
.expect_reasoning("I need to output the invoice details in JSON")
|
||||
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.run();
|
||||
|
||||
// tool call segment in reasoning
|
||||
tst.test(
|
||||
"Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Not the real call!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call></think>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({
|
||||
python_tool
|
||||
})
|
||||
.expect_reasoning("Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Not the real call!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// No args tool
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=empty_args>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ empty_args_tool })
|
||||
.expect(message_with_tool_calls("empty_args", "{}"))
|
||||
.run();
|
||||
|
||||
// No args tool with no properties defined
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=empty_args_no_props>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.enable_thinking(false)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ empty_args_tool_no_properties })
|
||||
.expect(message_with_tool_calls("empty_args_no_props", "{}"))
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// Ministral-3-14B-Reasoning-2512
|
||||
auto tst = peg_tester("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
|
||||
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
|
||||
.expect_content("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.expect(message_assist_thoughts)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})")
|
||||
@@ -1311,6 +1635,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -1323,6 +1648,20 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_reasoning("I need to output the invoice details in JSON")
|
||||
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.run();
|
||||
|
||||
// fake tool call marker in reasoning
|
||||
tst.test(
|
||||
"[THINK]Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more[/THINK]"
|
||||
R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.tools({ special_function_tool })
|
||||
.expect_reasoning("Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more")
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1425,6 +1764,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_reasoning("I need to output the invoice details in JSON")
|
||||
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.run();
|
||||
|
||||
// tool call segment in reasoning
|
||||
tst.test(
|
||||
"Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Not the real call!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call></think>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({
|
||||
python_tool
|
||||
})
|
||||
.expect_reasoning("Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Not the real call!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1481,9 +1864,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Google Gemma 2 2B - does not support tool calling
|
||||
auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja");
|
||||
|
||||
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run();
|
||||
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).expect_reconstruction().run();
|
||||
|
||||
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).run();
|
||||
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1526,7 +1909,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Test simple content-only template
|
||||
auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
{
|
||||
// IBM Granite (reasoning and tool calling model)
|
||||
@@ -1638,7 +2021,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Qwen3-Coder (tool calling with XML-style format)
|
||||
auto tst = peg_tester("models/templates/Qwen3-Coder.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
@@ -1650,6 +2033,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
"</tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -1678,6 +2062,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with code content (multiline)
|
||||
@@ -1698,6 +2083,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with code content (asian unicode chars)
|
||||
@@ -1715,6 +2101,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"格\"}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with HTML tag content
|
||||
@@ -1736,6 +2123,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "html", "{\"markup\": \"<html>\\n <head>\\n <title>Hello!</title>\\n </head>\\n</html>\"}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with TODO list (array of objects)
|
||||
@@ -1753,6 +2141,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "todo_list", "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test flexible optional argument ordering (2 required + 4 optional, reversed optional order)
|
||||
@@ -1769,6 +2158,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "tool_2req_4opt", R"({"req1": "hello", "req2": 42, "opt4": 100, "opt2": 200})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test flexible optional argument ordering (2 required + 5 optional, reversed optional order)
|
||||
@@ -1786,6 +2176,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "tool_2req_5opt", R"({"req1": "world", "req2": 7, "opt5": "last", "opt3": "middle", "opt1": "first"})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test flexible optional argument ordering (2 required + 5 optional, all 5 in shuffled order)
|
||||
@@ -1805,6 +2196,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "tool_2req_5opt", R"({"req1": "test", "req2": 99, "opt3": "c", "opt1": "a", "opt5": "e", "opt4": 4, "opt2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
{
|
||||
@@ -1885,6 +2277,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
tst.test("Hello, world!\nWhat's up?")
|
||||
.enable_thinking(false)
|
||||
.expect(message_assist)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Reasoning with content (forced-open mode - input starts after <think>)
|
||||
@@ -1892,6 +2285,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist_thoughts)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Tool call without reasoning
|
||||
@@ -1902,6 +2296,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.enable_thinking(false)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Tool call with reasoning (forced-open mode)
|
||||
@@ -1914,6 +2309,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_thoughts)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -1933,6 +2329,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// #20650: tool with no required args, model emits <tool_call>name</tool_call> with no arg tags.
|
||||
@@ -1950,6 +2347,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.tools({ no_args_tool })
|
||||
.expect_reasoning("Let me read the diff content.")
|
||||
.expect_tool_calls({{ "read_file_diff_md", "{}", {} }})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
}
|
||||
@@ -2208,22 +2606,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
|
||||
// Kimi-K2 old template
|
||||
auto tst = peg_tester("models/templates/moonshotai-Kimi-K2.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test(
|
||||
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>"
|
||||
"{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(kimi_id_special_func_tool_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Kimi-K2-Instruct
|
||||
auto tst2 = peg_tester("models/templates/Kimi-K2-Instruct.jinja", detailed_debug);
|
||||
tst2.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst2.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst2.test(
|
||||
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>"
|
||||
"{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(kimi_id_special_func_tool_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2297,6 +2697,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.tools({ empty_args_tool })
|
||||
.expect(simple_assist_msg("", "", "empty_args", "{}"))
|
||||
.run();
|
||||
|
||||
// fake tool call marker in reasoning
|
||||
tst.test(
|
||||
"<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
|
||||
"<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect_reasoning("Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm")
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
})
|
||||
.run();
|
||||
}
|
||||
|
||||
// Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format
|
||||
@@ -2306,6 +2719,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
tst.test("<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2314,7 +2728,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{
|
||||
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
|
||||
tst.test(
|
||||
"</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
@@ -2364,37 +2778,41 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// mistralai-Mistral-Nemo-Instruct-2407.jinja
|
||||
{
|
||||
auto tst = peg_tester("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_id)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
{
|
||||
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.1.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<function=special_function>{\"arg1\": 1}</function>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
// Functionary v3.2 - recipient-based format: >>>recipient\n{content}
|
||||
{
|
||||
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug);
|
||||
tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("special_function\n{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
// FireFunction
|
||||
{
|
||||
auto tst = peg_tester("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test(" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2455,10 +2873,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "models/templates/MiMo-VL.jinja", "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja",
|
||||
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" }) {
|
||||
auto tst = peg_tester(path, detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n</tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2481,6 +2900,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.expect(simple_assist_msg("Hello, world!\nWhat's up?", "Here are my reasoning steps:\nI'm\nthinking"))
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Reasoning + Tool calls
|
||||
@@ -2497,42 +2917,45 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Mistral Small 3.2 - FUNC_BRACKET_TAG format: [TOOL_CALLS]func_name[CALL_ID]id[ARGS]{...}
|
||||
{
|
||||
auto tst = peg_tester("models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_id)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
// Devstral
|
||||
{
|
||||
auto tst = peg_tester("models/templates/unsloth-mistral-Devstral-Small-2507.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
tst.test("Hello, world!\nWhat's up?[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_content)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// Llama 3.1
|
||||
auto tst = peg_tester("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
// Llama 3.2
|
||||
auto tst = peg_tester("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
// Llama 3.3
|
||||
auto tst = peg_tester("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
// GPT-OSS format tests
|
||||
@@ -2553,6 +2976,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect(message_assist_thoughts)
|
||||
.run();
|
||||
|
||||
// Analysis channel (reasoning) with final channel (content) with reasoning_format = none
|
||||
tst.test(
|
||||
"<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's "
|
||||
"up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_NONE)
|
||||
.expect_content("<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?")
|
||||
.run();
|
||||
|
||||
// Analysis channel only (partial) - still works when reasoning format is set
|
||||
tst.test("<|channel|>analysis<|message|>I'm\nthinking")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
@@ -2562,24 +2993,28 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
|
||||
// Tool call with recipient in role header: " to=functions.NAME<|channel|>analysis<|message|>JSON"
|
||||
tst.test(" to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Tool call with recipient in channel header: "<|channel|>analysis to=functions.NAME<|message|>JSON"
|
||||
tst.test("<|channel|>analysis to=functions.special_function<|message|>{\"arg1\": 1}")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Tool call with constraint: " to=functions.NAME<|channel|>analysis <|constrain|>json<|message|>JSON"
|
||||
tst.test(" to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Tool call in commentary channel (channel header variant)
|
||||
tst.test("<|channel|>commentary to=functions.special_function<|message|>{\"arg1\": 1}")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
@@ -2836,10 +3271,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// GigaChat V3
|
||||
{
|
||||
auto tst = peg_tester("models/templates/GigaChat3-10B-A1.8B.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<|message_sep|>\n\nfunction call<|role_sep|>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -2848,16 +3284,18 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_content)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
// GigaChat V3.1
|
||||
{
|
||||
auto tst = peg_tester("models/templates/GigaChat3.1-10B-A1.8B.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<|function_call|>{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -2866,6 +3304,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_content)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
}
|
||||
@@ -3002,6 +3441,10 @@ int main(int argc, char ** argv) {
|
||||
detailed_debug = true;
|
||||
common_log_set_verbosity_thold(999);
|
||||
}
|
||||
if (arg == "--force-reconstruction-test") {
|
||||
g_force_reconstruction_test = true;
|
||||
only_run_filtered = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (only_run_filtered) {
|
||||
|
||||
@@ -387,6 +387,24 @@ static void test_expressions(testing & t) {
|
||||
"Bob"
|
||||
);
|
||||
|
||||
test_template(t, "empty computed member defaults to undefined",
|
||||
"{{ a[]|default('fallback') }}",
|
||||
{{"a", {{"name", "Bob"}}}},
|
||||
"fallback"
|
||||
);
|
||||
|
||||
test_template(t, "empty computed member is undefined",
|
||||
"{{ a[] is undefined }}",
|
||||
{{"a", {{"name", "Bob"}}}},
|
||||
"True"
|
||||
);
|
||||
|
||||
test_template(t, "undefined computed member is undefined",
|
||||
"{{ a[undefined] is undefined }}",
|
||||
{{"a", {{"name", "Bob"}}}},
|
||||
"True"
|
||||
);
|
||||
|
||||
test_template(t, "array access",
|
||||
"{{ items[1] }}",
|
||||
{{"items", json::array({"a", "b", "c"})}},
|
||||
|
||||
@@ -1525,6 +1525,47 @@ int main() {
|
||||
}
|
||||
});
|
||||
|
||||
// C++ only tests (features not yet supported in JS/Python implementations)
|
||||
{
|
||||
fprintf(stderr, "#\n# Testing C++ only features\n#\n");
|
||||
auto run = [](const TestCase & tc) {
|
||||
fprintf(stderr, "- %s\n", tc.name.c_str());
|
||||
try {
|
||||
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
|
||||
tc.verify_status(SUCCESS);
|
||||
} catch (const std::invalid_argument & ex) {
|
||||
fprintf(stderr, "Error: %s\n", ex.what());
|
||||
tc.verify_status(FAILURE);
|
||||
}
|
||||
};
|
||||
|
||||
run({
|
||||
SUCCESS,
|
||||
"regexp with non-capturing group",
|
||||
R"""({
|
||||
"type": "string",
|
||||
"pattern": "^(?:foo|bar)baz$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" (("foo" | "bar") "baz") "\"" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""",
|
||||
});
|
||||
|
||||
run({
|
||||
SUCCESS,
|
||||
"regexp with nested non-capturing groups",
|
||||
R"""({
|
||||
"type": "string",
|
||||
"pattern": "^(?:(?:ab)+c)?d$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ((("ab")+ "c")? "d") "\"" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""",
|
||||
});
|
||||
}
|
||||
|
||||
if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
|
||||
fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
|
||||
} else {
|
||||
|
||||
@@ -61,8 +61,6 @@ static void test_reasoning_budget(
|
||||
|
||||
// Feed the sequence and track when forcing occurs
|
||||
for (size_t i = 0; i < sequence.size(); i++) {
|
||||
llama_sampler_accept(sampler, sequence[i]);
|
||||
|
||||
// Check if we're in forcing state by applying and seeing if logits are modified
|
||||
cur_p.selected = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
@@ -81,6 +79,8 @@ static void test_reasoning_budget(
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_accept(sampler, sequence[i]);
|
||||
|
||||
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
|
||||
|
||||
if (finite_count == 1) {
|
||||
@@ -167,9 +167,9 @@ int main(void) {
|
||||
}
|
||||
|
||||
// Test 2: Budget exhausted, forcing should occur
|
||||
// Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
|
||||
// Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
|
||||
// At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
|
||||
// Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1
|
||||
// i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1]
|
||||
// At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
@@ -179,13 +179,12 @@ int main(void) {
|
||||
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
|
||||
3); // forcing continues through i=3 (at i=4 state becomes DONE)
|
||||
3, // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces)
|
||||
4); // forcing continues through i=4 (accept at i=4 transitions to DONE)
|
||||
}
|
||||
|
||||
// Test 3: Activate immediately with budget=0, forcing should start right away
|
||||
// Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
|
||||
// This test needs start token to be in the sequence or use activate_immediately with start token present
|
||||
// Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
@@ -195,8 +194,8 @@ int main(void) {
|
||||
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
|
||||
0, // budget of 0 tokens
|
||||
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
|
||||
0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
|
||||
1); // forcing continues through i=1 (at i=2 state becomes DONE)
|
||||
0, // forcing starts at i=0 (initialized in FORCING, apply forces immediately)
|
||||
1); // forcing continues through i=1 (accept at i=1 transitions to DONE)
|
||||
}
|
||||
|
||||
// Test 4: No start/end tokens configured - passthrough (no forcing)
|
||||
@@ -214,7 +213,7 @@ int main(void) {
|
||||
|
||||
// Test 5: Activate immediately with budget > 0, count down then force
|
||||
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
|
||||
// So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
|
||||
// Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned)
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
@@ -224,8 +223,8 @@ int main(void) {
|
||||
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_COUNTING,
|
||||
1, // forcing starts at i=1 (after 2 accepts deplete budget)
|
||||
2); // forcing continues through i=2
|
||||
2, // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces)
|
||||
3); // forcing continues through i=3
|
||||
}
|
||||
|
||||
printf("OK (5 tests passed)\n");
|
||||
|
||||
+81
-18
@@ -100,7 +100,7 @@ struct cli_context {
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(
|
||||
llama_get_model(ctx_server.get_llama_context()));
|
||||
|
||||
@@ -224,10 +224,11 @@ struct cli_context {
|
||||
};
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<const std::string, 6> cmds = {
|
||||
static const std::array<const std::string, 7> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
@@ -258,7 +259,7 @@ static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std:
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
@@ -339,6 +340,8 @@ static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std:
|
||||
return matches;
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
@@ -430,7 +433,8 @@ int main(int argc, char ** argv) {
|
||||
console::log(" /exit or Ctrl+C stop or exit\n");
|
||||
console::log(" /regen regenerate the last response\n");
|
||||
console::log(" /clear clear the chat history\n");
|
||||
console::log(" /read add a text file\n");
|
||||
console::log(" /read <file> add a text file\n");
|
||||
console::log(" /glob <pattern> add text files using globbing pattern\n");
|
||||
if (inf.has_inp_image) {
|
||||
console::log(" /image <file> add an image file\n");
|
||||
}
|
||||
@@ -441,6 +445,27 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::string marker = ctx_cli.load_input_file(fname, false);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
return false;
|
||||
}
|
||||
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
@@ -525,22 +550,60 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
std::string marker = ctx_cli.load_input_file(fname, false);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
continue;
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, "~")) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = std::string(home) + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
|
||||
break;
|
||||
}
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
|
||||
@@ -188,6 +188,7 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--tags STRING` | set model tags, comma-separated (informational, not used for routing)<br/>(env: LLAMA_ARG_TAGS) |
|
||||
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
|
||||
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
|
||||
| `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)<br/>(env: LLAMA_ARG_REUSE_PORT) |
|
||||
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
|
||||
| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
|
||||
| `--webui-config JSON` | JSON that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG) |
|
||||
@@ -1774,6 +1775,16 @@ Apart from error types supported by OAI, we also have custom types that are spec
|
||||
}
|
||||
```
|
||||
|
||||
### Custom default Web UI preferences
|
||||
|
||||
You can specify default preferences for the web UI using `--webui-config <JSON config>` or `--webui-config-file <path to JSON config>`. For example, you can disable pasting long text as attachments and enable rendering Markdown in user messages with this command:
|
||||
|
||||
```bash
|
||||
./llama-server -m model.gguf --webui-config '{"pasteLongTextToFileLen": 0, "renderUserContentAsMarkdown": true}'
|
||||
```
|
||||
|
||||
You may find available preferences in [settings-config.ts](webui/src/lib/constants/settings-config.ts).
|
||||
|
||||
### Legacy completion web UI
|
||||
|
||||
A new chat-based UI has replaced the old completion-based since [this PR](https://github.com/ggml-org/llama.cpp/pull/10175). If you want to use the old completion, start the server with `--path ./tools/server/public_legacy`
|
||||
|
||||
Binary file not shown.
@@ -1110,7 +1110,7 @@ json oaicompat_chat_params_parse(
|
||||
reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
|
||||
}
|
||||
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
llama_params["reasoning_budget_tokens"] = reasoning_budget;
|
||||
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
|
||||
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
|
||||
|
||||
@@ -2493,7 +2493,7 @@ private:
|
||||
bool has_mtmd = false;
|
||||
|
||||
// check if we should process the image
|
||||
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
||||
// process the image
|
||||
size_t n_tokens_out = 0;
|
||||
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
|
||||
|
||||
@@ -32,13 +32,22 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
|
||||
|
||||
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
|
||||
|
||||
std::map<std::string, std::string> headers;
|
||||
for (auto [key, value] : req.headers) {
|
||||
auto new_key = key;
|
||||
if (string_starts_with(new_key, "X-Proxy-Header-")) {
|
||||
string_replace_all(new_key, "X-Proxy-Header-", "");
|
||||
}
|
||||
headers[new_key] = value;
|
||||
}
|
||||
|
||||
auto proxy = std::make_unique<server_http_proxy>(
|
||||
method,
|
||||
parsed_url.scheme,
|
||||
parsed_url.host,
|
||||
parsed_url.port,
|
||||
parsed_url.path,
|
||||
req.headers,
|
||||
headers,
|
||||
req.body,
|
||||
req.should_stop,
|
||||
600, // timeout_read (default to 10 minutes)
|
||||
|
||||
@@ -112,6 +112,16 @@ bool server_http_context::init(const common_params & params) {
|
||||
// set timeouts and change hostname and port
|
||||
srv->set_read_timeout (params.timeout_read);
|
||||
srv->set_write_timeout(params.timeout_write);
|
||||
srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) {
|
||||
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
|
||||
if (reuse_port) {
|
||||
#ifdef SO_REUSEPORT
|
||||
httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEPORT, 1);
|
||||
#else
|
||||
LOG_WRN("%s: SO_REUSEPORT is not supported\n", __func__);
|
||||
#endif
|
||||
}
|
||||
});
|
||||
|
||||
if (params.api_keys.size() == 1) {
|
||||
auto key = params.api_keys[0];
|
||||
|
||||
@@ -35,7 +35,7 @@ using server_http_res_ptr = std::unique_ptr<server_http_res>;
|
||||
|
||||
struct server_http_req {
|
||||
std::map<std::string, std::string> params; // path_params + query_params
|
||||
std::map<std::string, std::string> headers; // reserved for future use
|
||||
std::map<std::string, std::string> headers; // used by MCP proxy
|
||||
std::string path;
|
||||
std::string query_string; // query parameters string (e.g. "action=save")
|
||||
std::string body;
|
||||
|
||||
@@ -478,19 +478,17 @@ task_params server_task::params_from_json_cmpl(
|
||||
// Parse reasoning budget sampler parameters
|
||||
{
|
||||
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
|
||||
if (budget >= 0) {
|
||||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
}
|
||||
if (!end_tag.empty()) {
|
||||
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
||||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
}
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
}
|
||||
if (!end_tag.empty()) {
|
||||
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
||||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
|
||||
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, params.sampling.generation_prompt.c_str(),
|
||||
|
||||
@@ -101,38 +101,6 @@ static run_proc_result run_process(
|
||||
return res;
|
||||
}
|
||||
|
||||
// simple glob: * matches non-/ chars, ** matches anything including /
|
||||
static bool glob_match(const char * pattern, const char * str) {
|
||||
if (*pattern == '\0') {
|
||||
return *str == '\0';
|
||||
}
|
||||
if (pattern[0] == '*' && pattern[1] == '*') {
|
||||
const char * p = pattern + 2;
|
||||
if (*p == '/') p++;
|
||||
if (glob_match(p, str)) return true;
|
||||
if (*str != '\0') return glob_match(pattern, str + 1);
|
||||
return false;
|
||||
}
|
||||
if (*pattern == '*') {
|
||||
const char * p = pattern + 1;
|
||||
for (; *str != '\0' && *str != '/'; str++) {
|
||||
if (glob_match(p, str)) return true;
|
||||
}
|
||||
return glob_match(p, str);
|
||||
}
|
||||
if (*pattern == '?' && *str != '\0' && *str != '/') {
|
||||
return glob_match(pattern + 1, str + 1);
|
||||
}
|
||||
if (*pattern == *str) {
|
||||
return glob_match(pattern + 1, str + 1);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool glob_match(const std::string & pattern, const std::string & str) {
|
||||
return glob_match(pattern.c_str(), str.c_str());
|
||||
}
|
||||
|
||||
json server_tool::to_json() {
|
||||
return {
|
||||
{"display_name", display_name},
|
||||
|
||||
@@ -116,7 +116,7 @@ class ServerProcess:
|
||||
self.server_port = int(os.environ["PORT"])
|
||||
self.external_server = "DEBUG_EXTERNAL" in os.environ
|
||||
|
||||
def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
|
||||
def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None:
|
||||
if self.external_server:
|
||||
print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
|
||||
return
|
||||
@@ -288,7 +288,15 @@ class ServerProcess:
|
||||
server_instances.remove(self)
|
||||
if self.process:
|
||||
print(f"Stopping server with pid={self.process.pid}")
|
||||
self.process.kill()
|
||||
self.process.terminate()
|
||||
try:
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"Server pid={self.process.pid} did not terminate in time, killing")
|
||||
self.process.kill()
|
||||
self.process.wait(timeout=5)
|
||||
except Exception as e:
|
||||
print(f"Error waiting for server: {e}")
|
||||
self.process = None
|
||||
|
||||
def make_request(
|
||||
|
||||
+5
-11
@@ -10,9 +10,9 @@
|
||||
ModelsSelector,
|
||||
ModelsSelectorSheet
|
||||
} from '$lib/components/app';
|
||||
import { DialogChatSettings } from '$lib/components/app/dialogs';
|
||||
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { getChatSettingsDialogContext } from '$lib/contexts';
|
||||
import { FileTypeCategory } from '$lib/enums';
|
||||
import { getFileTypeCategory } from '$lib/utils';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
@@ -169,7 +169,7 @@
|
||||
selectorModelRef?.open();
|
||||
}
|
||||
|
||||
let showChatSettingsDialogWithMcpSection = $state(false);
|
||||
const chatSettingsDialog = getChatSettingsDialogContext();
|
||||
|
||||
let hasMcpPromptsSupport = $derived.by(() => {
|
||||
const perChatOverrides = conversationsStore.getAllMcpServerOverrides();
|
||||
@@ -197,7 +197,7 @@
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
onMcpSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
|
||||
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
|
||||
/>
|
||||
{:else}
|
||||
<ChatFormActionAttachmentsDropdown
|
||||
@@ -210,13 +210,13 @@
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
onMcpSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
|
||||
onMcpSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<McpServersSelector
|
||||
{disabled}
|
||||
onSettingsClick={() => (showChatSettingsDialogWithMcpSection = true)}
|
||||
onSettingsClick={() => chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -265,9 +265,3 @@
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<DialogChatSettings
|
||||
open={showChatSettingsDialogWithMcpSection}
|
||||
onOpenChange={(open) => (showChatSettingsDialogWithMcpSection = open)}
|
||||
initialSection={SETTINGS_SECTION_TITLES.MCP}
|
||||
/>
|
||||
|
||||
@@ -180,6 +180,10 @@
|
||||
chatActions.continueAssistantMessage(message);
|
||||
}
|
||||
|
||||
function handleForkConversation(options: { name: string; includeAttachments: boolean }) {
|
||||
chatActions.forkConversation(message, options);
|
||||
}
|
||||
|
||||
function handleNavigateToSibling(siblingId: string) {
|
||||
chatActions.navigateToSibling(siblingId);
|
||||
}
|
||||
@@ -285,6 +289,7 @@
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onForkConversation={handleForkConversation}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
@@ -303,6 +308,7 @@
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
onForkConversation={handleForkConversation}
|
||||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onRegenerate={handleRegenerate}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
|
||||
+67
-1
@@ -1,12 +1,16 @@
|
||||
<script lang="ts">
|
||||
import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
|
||||
import { Edit, Copy, RefreshCw, Trash2, ArrowRight, GitBranch } from '@lucide/svelte';
|
||||
import {
|
||||
ActionIcon,
|
||||
ChatMessageBranchingControls,
|
||||
DialogConfirmation
|
||||
} from '$lib/components/app';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
import { activeConversation } from '$lib/stores/conversations.svelte';
|
||||
|
||||
interface Props {
|
||||
role: MessageRole.USER | MessageRole.ASSISTANT;
|
||||
@@ -24,6 +28,7 @@
|
||||
onEdit?: () => void;
|
||||
onRegenerate?: () => void;
|
||||
onContinue?: () => void;
|
||||
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
|
||||
onDelete: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
@@ -42,6 +47,7 @@
|
||||
onConfirmDelete,
|
||||
onContinue,
|
||||
onDelete,
|
||||
onForkConversation,
|
||||
onNavigateToSibling,
|
||||
onShowDeleteDialogChange,
|
||||
onRegenerate,
|
||||
@@ -53,10 +59,27 @@
|
||||
onRawOutputToggle
|
||||
}: Props = $props();
|
||||
|
||||
let showForkDialog = $state(false);
|
||||
let forkName = $state('');
|
||||
let forkIncludeAttachments = $state(true);
|
||||
|
||||
function handleConfirmDelete() {
|
||||
onConfirmDelete();
|
||||
onShowDeleteDialogChange(false);
|
||||
}
|
||||
|
||||
function handleOpenForkDialog() {
|
||||
const conv = activeConversation();
|
||||
|
||||
forkName = `Fork of ${conv?.name ?? 'Conversation'}`;
|
||||
forkIncludeAttachments = true;
|
||||
showForkDialog = true;
|
||||
}
|
||||
|
||||
function handleConfirmFork() {
|
||||
onForkConversation?.({ name: forkName.trim(), includeAttachments: forkIncludeAttachments });
|
||||
showForkDialog = false;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-between">
|
||||
@@ -86,6 +109,10 @@
|
||||
<ActionIcon icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
|
||||
{/if}
|
||||
|
||||
{#if onForkConversation}
|
||||
<ActionIcon icon={GitBranch} tooltip="Fork conversation" onclick={handleOpenForkDialog} />
|
||||
{/if}
|
||||
|
||||
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
</div>
|
||||
</div>
|
||||
@@ -116,3 +143,42 @@
|
||||
onConfirm={handleConfirmDelete}
|
||||
onCancel={() => onShowDeleteDialogChange(false)}
|
||||
/>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showForkDialog}
|
||||
title="Fork Conversation"
|
||||
description="Create a new conversation branching from this message."
|
||||
confirmText="Fork"
|
||||
cancelText="Cancel"
|
||||
icon={GitBranch}
|
||||
onConfirm={handleConfirmFork}
|
||||
onCancel={() => (showForkDialog = false)}
|
||||
>
|
||||
<div class="flex flex-col gap-4 py-2">
|
||||
<div class="flex flex-col gap-2">
|
||||
<Label for="fork-name">Title</Label>
|
||||
|
||||
<Input
|
||||
id="fork-name"
|
||||
class="text-foreground"
|
||||
placeholder="Enter fork name"
|
||||
type="text"
|
||||
bind:value={forkName}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<Checkbox
|
||||
id="fork-attachments"
|
||||
checked={forkIncludeAttachments}
|
||||
onCheckedChange={(checked) => {
|
||||
forkIncludeAttachments = checked === true;
|
||||
}}
|
||||
/>
|
||||
|
||||
<Label for="fork-attachments" class="cursor-pointer text-sm font-normal">
|
||||
Include all attachments
|
||||
</Label>
|
||||
</div>
|
||||
</div>
|
||||
</DialogConfirmation>
|
||||
|
||||
+3
@@ -39,6 +39,7 @@
|
||||
onContinue?: () => void;
|
||||
onDelete: () => void;
|
||||
onEdit?: () => void;
|
||||
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
onRegenerate: (modelOverride?: string) => void;
|
||||
onShowDeleteDialogChange: (show: boolean) => void;
|
||||
@@ -58,6 +59,7 @@
|
||||
onCopy,
|
||||
onDelete,
|
||||
onEdit,
|
||||
onForkConversation,
|
||||
onNavigateToSibling,
|
||||
onRegenerate,
|
||||
onShowDeleteDialogChange,
|
||||
@@ -345,6 +347,7 @@
|
||||
onContinue={currentConfig.enableContinueGeneration && !hasReasoningMarkers
|
||||
? onContinue
|
||||
: undefined}
|
||||
{onForkConversation}
|
||||
{onDelete}
|
||||
{onConfirmDelete}
|
||||
{onNavigateToSibling}
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
onEdit: () => void;
|
||||
onDelete: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
onForkConversation?: (options: { name: string; includeAttachments: boolean }) => void;
|
||||
onShowDeleteDialogChange: (show: boolean) => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
onCopy: () => void;
|
||||
@@ -35,6 +36,7 @@
|
||||
onEdit,
|
||||
onDelete,
|
||||
onConfirmDelete,
|
||||
onForkConversation,
|
||||
onShowDeleteDialogChange,
|
||||
onNavigateToSibling,
|
||||
onCopy
|
||||
@@ -114,6 +116,7 @@
|
||||
{onCopy}
|
||||
{onDelete}
|
||||
{onEdit}
|
||||
{onForkConversation}
|
||||
{onNavigateToSibling}
|
||||
{onShowDeleteDialogChange}
|
||||
{siblingInfo}
|
||||
|
||||
@@ -79,6 +79,13 @@
|
||||
onUserAction?.();
|
||||
await chatStore.continueAssistantMessage(message.id);
|
||||
refreshAllMessages();
|
||||
},
|
||||
|
||||
forkConversation: async (
|
||||
message: DatabaseMessage,
|
||||
options: { name: string; includeAttachments: boolean }
|
||||
) => {
|
||||
await conversationsStore.forkConversation(message.id, options);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
<script lang="ts">
|
||||
import { Settings } from '@lucide/svelte';
|
||||
import { DialogChatSettings } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { useSidebar } from '$lib/components/ui/sidebar';
|
||||
import { getChatSettingsDialogContext } from '$lib/contexts';
|
||||
|
||||
const sidebar = useSidebar();
|
||||
|
||||
let settingsOpen = $state(false);
|
||||
|
||||
function toggleSettings() {
|
||||
settingsOpen = true;
|
||||
}
|
||||
const chatSettingsDialog = getChatSettingsDialogContext();
|
||||
</script>
|
||||
|
||||
<header
|
||||
@@ -22,12 +17,10 @@
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon-lg"
|
||||
onclick={toggleSettings}
|
||||
onclick={() => chatSettingsDialog.open()}
|
||||
class="rounded-full backdrop-blur-lg"
|
||||
>
|
||||
<Settings class="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<DialogChatSettings open={settingsOpen} onOpenChange={(open) => (settingsOpen = open)} />
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { Trash2, Pencil } from '@lucide/svelte';
|
||||
import { ChatSidebarConversationItem, DialogConfirmation } from '$lib/components/app';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
|
||||
import * as Sidebar from '$lib/components/ui/sidebar';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
import {
|
||||
conversationsStore,
|
||||
conversations,
|
||||
buildConversationTree
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { getPreviewText } from '$lib/utils';
|
||||
import ChatSidebarActions from './ChatSidebarActions.svelte';
|
||||
@@ -18,6 +23,7 @@
|
||||
let isSearchModeActive = $state(false);
|
||||
let searchQuery = $state('');
|
||||
let showDeleteDialog = $state(false);
|
||||
let deleteWithForks = $state(false);
|
||||
let showEditDialog = $state(false);
|
||||
let selectedConversation = $state<DatabaseConversation | null>(null);
|
||||
let editedName = $state('');
|
||||
@@ -35,10 +41,30 @@
|
||||
return conversations();
|
||||
});
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let selectedConversationHasDescendants = $derived.by(() => {
|
||||
if (!selectedConversation) return false;
|
||||
|
||||
const allConvs = conversations();
|
||||
const queue = [selectedConversation.id];
|
||||
|
||||
while (queue.length > 0) {
|
||||
const parentId = queue.pop()!;
|
||||
|
||||
for (const c of allConvs) {
|
||||
if (c.forkedFromConversationId === parentId) return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
async function handleDeleteConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (conversation) {
|
||||
selectedConversation = conversation;
|
||||
deleteWithForks = false;
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
@@ -54,11 +80,14 @@
|
||||
|
||||
function handleConfirmDelete() {
|
||||
if (selectedConversation) {
|
||||
const convId = selectedConversation.id;
|
||||
const withForks = deleteWithForks;
|
||||
showDeleteDialog = false;
|
||||
|
||||
setTimeout(() => {
|
||||
conversationsStore.deleteConversation(selectedConversation.id);
|
||||
selectedConversation = null;
|
||||
conversationsStore.deleteConversation(convId, {
|
||||
deleteWithForks: withForks
|
||||
});
|
||||
}, 100); // Wait for animation to finish
|
||||
}
|
||||
}
|
||||
@@ -110,7 +139,7 @@
|
||||
</script>
|
||||
|
||||
<ScrollArea class="h-[100vh]">
|
||||
<Sidebar.Header class=" top-0 z-10 gap-6 bg-sidebar/50 px-4 py-4 pb-2 backdrop-blur-lg md:sticky">
|
||||
<Sidebar.Header class=" top-0 z-10 gap-4 bg-sidebar/50 p-4 pb-2 backdrop-blur-lg md:sticky">
|
||||
<a href="#/" onclick={handleMobileSidebarItemClick}>
|
||||
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">llama.cpp</h1>
|
||||
</a>
|
||||
@@ -118,7 +147,7 @@
|
||||
<ChatSidebarActions {handleMobileSidebarItemClick} bind:isSearchModeActive bind:searchQuery />
|
||||
</Sidebar.Header>
|
||||
|
||||
<Sidebar.Group class="mt-4 space-y-2 p-0 px-4">
|
||||
<Sidebar.Group class="mt-2 space-y-2 p-0 px-4">
|
||||
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
|
||||
<Sidebar.GroupLabel>
|
||||
{isSearchModeActive ? 'Search results' : 'Conversations'}
|
||||
@@ -127,15 +156,17 @@
|
||||
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each filteredConversations as conversation (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1">
|
||||
{#each conversationTree as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<ChatSidebarConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId
|
||||
}}
|
||||
{depth}
|
||||
{handleMobileSidebarItemClick}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
@@ -146,7 +177,7 @@
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
|
||||
{#if filteredConversations.length === 0}
|
||||
{#if conversationTree.length === 0}
|
||||
<div class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{searchQuery.length > 0
|
||||
@@ -177,35 +208,40 @@
|
||||
showDeleteDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
/>
|
||||
>
|
||||
{#if selectedConversationHasDescendants}
|
||||
<div class="flex items-center gap-2 py-2">
|
||||
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
|
||||
|
||||
<AlertDialog.Root bind:open={showEditDialog}>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title>Edit Conversation Name</AlertDialog.Title>
|
||||
<AlertDialog.Description>
|
||||
<Input
|
||||
class="mt-4 text-foreground"
|
||||
onkeydown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
handleConfirmEdit();
|
||||
}
|
||||
}}
|
||||
placeholder="Enter a new name"
|
||||
type="text"
|
||||
bind:value={editedName}
|
||||
/>
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Cancel
|
||||
onclick={() => {
|
||||
showEditDialog = false;
|
||||
selectedConversation = null;
|
||||
}}>Cancel</AlertDialog.Cancel
|
||||
>
|
||||
<AlertDialog.Action onclick={handleConfirmEdit}>Save</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
|
||||
</div>
|
||||
{/if}
|
||||
</DialogConfirmation>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showEditDialog}
|
||||
title="Edit Conversation Name"
|
||||
description=""
|
||||
confirmText="Save"
|
||||
cancelText="Cancel"
|
||||
icon={Pencil}
|
||||
onConfirm={handleConfirmEdit}
|
||||
onCancel={() => {
|
||||
showEditDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
onKeydown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
e.stopImmediatePropagation();
|
||||
handleConfirmEdit();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Input
|
||||
class="text-foreground"
|
||||
placeholder="Enter a new name"
|
||||
type="text"
|
||||
bind:value={editedName}
|
||||
/>
|
||||
</DialogConfirmation>
|
||||
|
||||
+25
-4
@@ -3,6 +3,9 @@
|
||||
import { KeyboardShortcutInfo } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import { McpLogo } from '$lib/components/app';
|
||||
import { SETTINGS_SECTION_TITLES } from '$lib/constants';
|
||||
import { getChatSettingsDialogContext } from '$lib/contexts';
|
||||
|
||||
interface Props {
|
||||
handleMobileSidebarItemClick: () => void;
|
||||
@@ -18,6 +21,8 @@
|
||||
|
||||
let searchInput: HTMLInputElement | null = $state(null);
|
||||
|
||||
const chatSettingsDialog = getChatSettingsDialogContext();
|
||||
|
||||
function handleSearchModeDeactivate() {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
@@ -30,7 +35,7 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="space-y-0.5">
|
||||
<div class="my-1 space-y-1">
|
||||
{#if isSearchModeActive}
|
||||
<div class="relative">
|
||||
<Search class="absolute top-2.5 left-2 h-4 w-4 text-muted-foreground" />
|
||||
@@ -50,13 +55,14 @@
|
||||
</div>
|
||||
{:else}
|
||||
<Button
|
||||
class="w-full justify-between hover:[&>kbd]:opacity-100"
|
||||
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
|
||||
href="?new_chat=true#/"
|
||||
onclick={handleMobileSidebarItemClick}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<SquarePen class="h-4 w-4" />
|
||||
|
||||
New chat
|
||||
</div>
|
||||
|
||||
@@ -64,7 +70,7 @@
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
class="w-full justify-between hover:[&>kbd]:opacity-100"
|
||||
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
|
||||
onclick={() => {
|
||||
isSearchModeActive = true;
|
||||
}}
|
||||
@@ -72,10 +78,25 @@
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<Search class="h-4 w-4" />
|
||||
Search conversations
|
||||
|
||||
Search
|
||||
</div>
|
||||
|
||||
<KeyboardShortcutInfo keys={['cmd', 'k']} />
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
class="w-full justify-between backdrop-blur-none! hover:[&>kbd]:opacity-100"
|
||||
onclick={() => {
|
||||
chatSettingsDialog.open(SETTINGS_SECTION_TITLES.MCP);
|
||||
}}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<McpLogo class="h-4 w-4" />
|
||||
|
||||
MCP Servers
|
||||
</div>
|
||||
</Button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
+36
-5
@@ -1,13 +1,23 @@
|
||||
<script lang="ts">
|
||||
import { Trash2, Pencil, MoreHorizontal, Download, Loader2, Square } from '@lucide/svelte';
|
||||
import {
|
||||
Trash2,
|
||||
Pencil,
|
||||
MoreHorizontal,
|
||||
Download,
|
||||
Loader2,
|
||||
Square,
|
||||
GitBranch
|
||||
} from '@lucide/svelte';
|
||||
import { DropdownMenuActions } from '$lib/components/app';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { FORK_TREE_DEPTH_PADDING } from '$lib/constants';
|
||||
import { getAllLoadingChats } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
isActive?: boolean;
|
||||
depth?: number;
|
||||
conversation: DatabaseConversation;
|
||||
handleMobileSidebarItemClick?: () => void;
|
||||
onDelete?: (id: string) => void;
|
||||
@@ -23,7 +33,8 @@
|
||||
onEdit,
|
||||
onSelect,
|
||||
onStop,
|
||||
isActive = false
|
||||
isActive = false,
|
||||
depth = 0
|
||||
}: Props = $props();
|
||||
|
||||
let renderActionsDropdown = $state(false);
|
||||
@@ -88,14 +99,34 @@
|
||||
|
||||
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
|
||||
<button
|
||||
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg px-3 py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||
? 'bg-foreground/5 text-accent-foreground'
|
||||
: ''}"
|
||||
: ''} px-3"
|
||||
onclick={handleSelect}
|
||||
onmouseover={handleMouseOver}
|
||||
onmouseleave={handleMouseLeave}
|
||||
>
|
||||
<div class="flex min-w-0 flex-1 items-center gap-2">
|
||||
<div
|
||||
class="flex min-w-0 flex-1 items-center gap-2"
|
||||
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
|
||||
>
|
||||
{#if depth > 0}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<a
|
||||
href="#/chat/{conversation.forkedFromConversationId}"
|
||||
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<GitBranch class="h-3.5 w-3.5" />
|
||||
</a>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>See parent conversation</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{/if}
|
||||
|
||||
{#if isLoading}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<script lang="ts">
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import type { Component } from 'svelte';
|
||||
import type { Component, Snippet } from 'svelte';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
@@ -14,6 +14,7 @@
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
onKeydown?: (event: KeyboardEvent) => void;
|
||||
children?: Snippet;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -26,7 +27,8 @@
|
||||
icon,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
onKeydown
|
||||
onKeydown,
|
||||
children
|
||||
}: Props = $props();
|
||||
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
@@ -60,6 +62,10 @@
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
|
||||
{#if children}
|
||||
{@render children()}
|
||||
{/if}
|
||||
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Cancel onclick={onCancel}>{cancelText}</AlertDialog.Cancel>
|
||||
<AlertDialog.Action
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
showRaw = undefined,
|
||||
aliases,
|
||||
tags,
|
||||
class: className = ''
|
||||
class: className = '',
|
||||
...rest
|
||||
}: Props = $props();
|
||||
|
||||
const badgeClass =
|
||||
@@ -36,9 +37,9 @@
|
||||
</script>
|
||||
|
||||
{#if resolvedShowRaw}
|
||||
<TruncatedText class="font-medium {className}" showTooltip={false} text={modelId} />
|
||||
<TruncatedText class="font-medium {className}" showTooltip={false} text={modelId} {...rest} />
|
||||
{:else}
|
||||
<span class="flex min-w-0 flex-wrap items-center gap-1 {className}">
|
||||
<span class="flex min-w-0 flex-wrap items-center gap-1 {className}" {...rest}>
|
||||
<span class="min-w-0 truncate font-medium">
|
||||
{#if showOrgName && parsed.orgName && !(aliases && aliases.length > 0)}{parsed.orgName}/{/if}{displayName}
|
||||
</span>
|
||||
|
||||
@@ -271,50 +271,49 @@
|
||||
{#if isRouter}
|
||||
<DropdownMenu.Root bind:open={isOpen} onOpenChange={handleOpenChange}>
|
||||
<DropdownMenu.Trigger
|
||||
disabled={disabled || updating}
|
||||
onclick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}}
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class={cn(
|
||||
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-muted-foreground/10 px-1.5 py-1 text-xs transition hover:text-foreground focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60`,
|
||||
!isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
class={cn(
|
||||
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-muted-foreground/10 px-1.5 py-1 text-xs transition hover:text-foreground focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60`,
|
||||
!isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
? 'text-foreground'
|
||||
: isHighlightedCurrentModelActive
|
||||
? 'text-foreground'
|
||||
: isHighlightedCurrentModelActive
|
||||
? 'text-foreground'
|
||||
: 'text-muted-foreground',
|
||||
isOpen ? 'text-foreground' : ''
|
||||
)}
|
||||
style="max-width: min(calc(100cqw - 9rem), 20rem)"
|
||||
disabled={disabled || updating}
|
||||
>
|
||||
<Package class="h-3.5 w-3.5" />
|
||||
: 'text-muted-foreground',
|
||||
isOpen ? 'text-foreground' : ''
|
||||
)}
|
||||
style="max-width: min(calc(100cqw - 9rem), 20rem)"
|
||||
disabled={disabled || updating}
|
||||
>
|
||||
<Package class="h-3.5 w-3.5" />
|
||||
|
||||
{#if selectedOption}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="min-w-0 overflow-hidden">
|
||||
<ModelId modelId={selectedOption.model} class="min-w-0" showOrgName />
|
||||
</Tooltip.Trigger>
|
||||
{#if selectedOption}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
<ModelId
|
||||
modelId={selectedOption.model}
|
||||
class="min-w-0 overflow-hidden"
|
||||
showOrgName
|
||||
{...props}
|
||||
/>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p class="font-mono">{selectedOption.model}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<span class="min-w-0 font-medium">Select model</span>
|
||||
{/if}
|
||||
<Tooltip.Content>
|
||||
<p class="font-mono">{selectedOption.model}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<span class="min-w-0 font-medium">Select model</span>
|
||||
{/if}
|
||||
|
||||
{#if updating || isLoadingModel}
|
||||
<Loader2 class="h-3 w-3.5 animate-spin" />
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5" />
|
||||
{/if}
|
||||
</button>
|
||||
{#if updating || isLoadingModel}
|
||||
<Loader2 class="h-3 w-3.5 animate-spin" />
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5" />
|
||||
{/if}
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content
|
||||
@@ -407,8 +406,16 @@
|
||||
|
||||
{#if selectedOption}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="min-w-0 overflow-hidden">
|
||||
<ModelId modelId={selectedOption.model} class="min-w-0" showOrgName />
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
<ModelId
|
||||
modelId={selectedOption.model}
|
||||
class="min-w-0 overflow-hidden"
|
||||
showOrgName
|
||||
{...props}
|
||||
/>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
export const CONTEXT_KEY_MESSAGE_EDIT = 'chat-message-edit';
|
||||
export const CONTEXT_KEY_CHAT_ACTIONS = 'chat-actions';
|
||||
export const CONTEXT_KEY_CHAT_SETTINGS_DIALOG = 'chat-settings-dialog';
|
||||
@@ -10,6 +10,7 @@ export * from './cache';
|
||||
export * from './chat-form';
|
||||
export * from './code-blocks';
|
||||
export * from './code';
|
||||
export * from './context-keys';
|
||||
export * from './css-classes';
|
||||
export * from './favicon';
|
||||
export * from './floating-ui-constraints';
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
export const FORK_TREE_DEPTH_PADDING = 8;
|
||||
export const SYSTEM_MESSAGE_PLACEHOLDER = 'System message';
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { getContext, setContext } from 'svelte';
|
||||
import { CONTEXT_KEY_CHAT_ACTIONS } from '$lib/constants';
|
||||
|
||||
export interface ChatActionsContext {
|
||||
copy: (message: DatabaseMessage) => void;
|
||||
@@ -21,9 +22,13 @@ export interface ChatActionsContext {
|
||||
) => void;
|
||||
regenerateWithBranching: (message: DatabaseMessage, modelOverride?: string) => void;
|
||||
continueAssistantMessage: (message: DatabaseMessage) => void;
|
||||
forkConversation: (
|
||||
message: DatabaseMessage,
|
||||
options: { name: string; includeAttachments: boolean }
|
||||
) => void;
|
||||
}
|
||||
|
||||
const CHAT_ACTIONS_KEY = Symbol.for('chat-actions');
|
||||
const CHAT_ACTIONS_KEY = Symbol.for(CONTEXT_KEY_CHAT_ACTIONS);
|
||||
|
||||
export function setChatActionsContext(ctx: ChatActionsContext): ChatActionsContext {
|
||||
return setContext(CHAT_ACTIONS_KEY, ctx);
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import { getContext, setContext } from 'svelte';
|
||||
import type { SettingsSectionTitle } from '$lib/constants';
|
||||
import { CONTEXT_KEY_CHAT_SETTINGS_DIALOG } from '$lib/constants';
|
||||
|
||||
export interface ChatSettingsDialogContext {
|
||||
open: (initialSection?: SettingsSectionTitle) => void;
|
||||
}
|
||||
|
||||
const CHAT_SETTINGS_DIALOG_KEY = Symbol.for(CONTEXT_KEY_CHAT_SETTINGS_DIALOG);
|
||||
|
||||
export function setChatSettingsDialogContext(
|
||||
ctx: ChatSettingsDialogContext
|
||||
): ChatSettingsDialogContext {
|
||||
return setContext(CHAT_SETTINGS_DIALOG_KEY, ctx);
|
||||
}
|
||||
|
||||
export function getChatSettingsDialogContext(): ChatSettingsDialogContext {
|
||||
return getContext(CHAT_SETTINGS_DIALOG_KEY);
|
||||
}
|
||||
@@ -11,3 +11,9 @@ export {
|
||||
setChatActionsContext,
|
||||
type ChatActionsContext
|
||||
} from './chat-actions.context';
|
||||
|
||||
export {
|
||||
getChatSettingsDialogContext,
|
||||
setChatSettingsDialogContext,
|
||||
type ChatSettingsDialogContext
|
||||
} from './chat-settings-dialog.context';
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { getContext, setContext } from 'svelte';
|
||||
import { CONTEXT_KEY_MESSAGE_EDIT } from '$lib/constants';
|
||||
|
||||
export interface MessageEditState {
|
||||
readonly isEditing: boolean;
|
||||
@@ -22,7 +23,7 @@ export interface MessageEditActions {
|
||||
|
||||
export type MessageEditContext = MessageEditState & MessageEditActions;
|
||||
|
||||
const MESSAGE_EDIT_KEY = Symbol.for('chat-message-edit');
|
||||
const MESSAGE_EDIT_KEY = Symbol.for(CONTEXT_KEY_MESSAGE_EDIT);
|
||||
|
||||
/**
|
||||
* Sets the message edit context. Call this in the parent component (ChatMessage.svelte).
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import Dexie, { type EntityTable } from 'dexie';
|
||||
import { findDescendantMessages, uuid } from '$lib/utils';
|
||||
import { findDescendantMessages, uuid, filterByLeafNodeId } from '$lib/utils';
|
||||
import type { McpServerOverride } from '$lib/types/database';
|
||||
|
||||
class LlamacppDatabase extends Dexie {
|
||||
conversations!: EntityTable<DatabaseConversation, string>;
|
||||
@@ -173,8 +174,47 @@ export class DatabaseService {
|
||||
*
|
||||
* @param id - Conversation ID
|
||||
*/
|
||||
static async deleteConversation(id: string): Promise<void> {
|
||||
static async deleteConversation(
|
||||
id: string,
|
||||
options?: { deleteWithForks?: boolean }
|
||||
): Promise<void> {
|
||||
await db.transaction('rw', [db.conversations, db.messages], async () => {
|
||||
if (options?.deleteWithForks) {
|
||||
// Recursively collect all descendant IDs
|
||||
const idsToDelete: string[] = [];
|
||||
const queue = [id];
|
||||
|
||||
while (queue.length > 0) {
|
||||
const parentId = queue.pop()!;
|
||||
const children = await db.conversations
|
||||
.filter((c) => c.forkedFromConversationId === parentId)
|
||||
.toArray();
|
||||
|
||||
for (const child of children) {
|
||||
idsToDelete.push(child.id);
|
||||
queue.push(child.id);
|
||||
}
|
||||
}
|
||||
|
||||
for (const forkId of idsToDelete) {
|
||||
await db.conversations.delete(forkId);
|
||||
await db.messages.where('convId').equals(forkId).delete();
|
||||
}
|
||||
} else {
|
||||
// Reparent direct children to deleted conv's parent
|
||||
const conv = await db.conversations.get(id);
|
||||
const newParent = conv?.forkedFromConversationId;
|
||||
const directChildren = await db.conversations
|
||||
.filter((c) => c.forkedFromConversationId === id)
|
||||
.toArray();
|
||||
|
||||
for (const child of directChildren) {
|
||||
await db.conversations.update(child.id, {
|
||||
forkedFromConversationId: newParent ?? undefined
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
await db.conversations.delete(id);
|
||||
await db.messages.where('convId').equals(id).delete();
|
||||
});
|
||||
@@ -364,4 +404,88 @@ export class DatabaseService {
|
||||
return { imported: importedCount, skipped: skippedCount };
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
* Forking
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Forks a conversation at a specific message, creating a new conversation
|
||||
* containing all messages from the root up to (and including) the target message.
|
||||
*
|
||||
* @param sourceConvId - The source conversation ID
|
||||
* @param atMessageId - The message ID to fork at (the new conversation ends here)
|
||||
* @param options - Fork options (name and whether to include attachments)
|
||||
* @returns The newly created conversation
|
||||
*/
|
||||
static async forkConversation(
|
||||
sourceConvId: string,
|
||||
atMessageId: string,
|
||||
options: { name: string; includeAttachments: boolean }
|
||||
): Promise<DatabaseConversation> {
|
||||
return await db.transaction('rw', [db.conversations, db.messages], async () => {
|
||||
const sourceConv = await db.conversations.get(sourceConvId);
|
||||
if (!sourceConv) {
|
||||
throw new Error(`Source conversation ${sourceConvId} not found`);
|
||||
}
|
||||
|
||||
const allMessages = await db.messages.where('convId').equals(sourceConvId).toArray();
|
||||
|
||||
const pathMessages = filterByLeafNodeId(allMessages, atMessageId, true) as DatabaseMessage[];
|
||||
if (pathMessages.length === 0) {
|
||||
throw new Error(`Could not resolve message path to ${atMessageId}`);
|
||||
}
|
||||
|
||||
const idMap = new Map<string, string>();
|
||||
|
||||
for (const msg of pathMessages) {
|
||||
idMap.set(msg.id, uuid());
|
||||
}
|
||||
|
||||
const newConvId = uuid();
|
||||
const clonedMessages: DatabaseMessage[] = pathMessages.map((msg) => {
|
||||
const newId = idMap.get(msg.id)!;
|
||||
const newParent = msg.parent ? (idMap.get(msg.parent) ?? null) : null;
|
||||
const newChildren = msg.children
|
||||
.filter((childId: string) => idMap.has(childId))
|
||||
.map((childId: string) => idMap.get(childId)!);
|
||||
|
||||
return {
|
||||
...msg,
|
||||
id: newId,
|
||||
convId: newConvId,
|
||||
parent: newParent,
|
||||
children: newChildren,
|
||||
extra: options.includeAttachments ? msg.extra : undefined
|
||||
};
|
||||
});
|
||||
|
||||
const lastClonedMessage = clonedMessages[clonedMessages.length - 1];
|
||||
const newConv: DatabaseConversation = {
|
||||
id: newConvId,
|
||||
name: options.name,
|
||||
lastModified: Date.now(),
|
||||
currNode: lastClonedMessage.id,
|
||||
forkedFromConversationId: sourceConvId,
|
||||
mcpServerOverrides: sourceConv.mcpServerOverrides
|
||||
? sourceConv.mcpServerOverrides.map((o: McpServerOverride) => ({
|
||||
serverId: o.serverId,
|
||||
enabled: o.enabled
|
||||
}))
|
||||
: undefined
|
||||
};
|
||||
|
||||
await db.conversations.add(newConv);
|
||||
|
||||
for (const msg of clonedMessages) {
|
||||
await db.messages.add(msg);
|
||||
}
|
||||
|
||||
return newConv;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,13 @@ import type {
|
||||
MCPResourceContent,
|
||||
MCPReadResourceResult
|
||||
} from '$lib/types';
|
||||
import { buildProxiedUrl, throwIfAborted, isAbortError, createBase64DataUrl } from '$lib/utils';
|
||||
import {
|
||||
buildProxiedUrl,
|
||||
buildProxiedHeaders,
|
||||
throwIfAborted,
|
||||
isAbortError,
|
||||
createBase64DataUrl
|
||||
} from '$lib/utils';
|
||||
|
||||
interface ToolResultContentItem {
|
||||
type: string;
|
||||
@@ -118,7 +124,7 @@ export class MCPService {
|
||||
const requestInit: RequestInit = {};
|
||||
|
||||
if (config.headers) {
|
||||
requestInit.headers = config.headers;
|
||||
requestInit.headers = buildProxiedHeaders(config.headers);
|
||||
}
|
||||
|
||||
if (config.credentials) {
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
filterByLeafNodeId,
|
||||
findDescendantMessages,
|
||||
findLeafNode,
|
||||
findMessageById,
|
||||
isAbortError
|
||||
} from '$lib/utils';
|
||||
import {
|
||||
@@ -416,7 +417,7 @@ class ChatStore {
|
||||
if (!activeConv) return false;
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const systemMessage = allMessages.find((m) => m.id === messageId);
|
||||
const systemMessage = findMessageById(allMessages, messageId);
|
||||
if (!systemMessage || systemMessage.role !== MessageRole.SYSTEM) return false;
|
||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||
if (!rootMessage) return false;
|
||||
@@ -878,7 +879,7 @@ class ChatStore {
|
||||
const msg = conversationsStore.activeMessages[idx];
|
||||
if (msg.role !== MessageRole.ASSISTANT) return;
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const parentMessage = allMessages.find((m) => m.id === msg.parent);
|
||||
const parentMessage = findMessageById(allMessages, msg.parent);
|
||||
if (!parentMessage) return;
|
||||
this.setChatLoading(activeConv.id, true);
|
||||
this.clearChatStreaming(activeConv.id);
|
||||
@@ -928,7 +929,7 @@ class ChatStore {
|
||||
if (!activeConv)
|
||||
return { totalCount: 0, userMessages: 0, assistantMessages: 0, messageTypes: [] };
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const messageToDelete = allMessages.find((m) => m.id === messageId);
|
||||
const messageToDelete = findMessageById(allMessages, messageId);
|
||||
|
||||
// For system messages, don't count descendants as they will be preserved (reparented to root)
|
||||
if (messageToDelete?.role === MessageRole.SYSTEM) {
|
||||
@@ -975,7 +976,7 @@ class ChatStore {
|
||||
if (!activeConv) return;
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const messageToDelete = allMessages.find((m) => m.id === messageId);
|
||||
const messageToDelete = findMessageById(allMessages, messageId);
|
||||
|
||||
if (!messageToDelete) return;
|
||||
|
||||
@@ -1024,7 +1025,7 @@ class ChatStore {
|
||||
this.clearChatStreaming(activeConv.id);
|
||||
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const dbMessage = allMessages.find((m) => m.id === messageId);
|
||||
const dbMessage = findMessageById(allMessages, messageId);
|
||||
|
||||
if (!dbMessage) {
|
||||
this.setChatLoading(activeConv.id, false);
|
||||
@@ -1265,35 +1266,56 @@ class ChatStore {
|
||||
let result = this.getMessageByIdWithRole(messageId, MessageRole.USER);
|
||||
if (!result) result = this.getMessageByIdWithRole(messageId, MessageRole.SYSTEM);
|
||||
if (!result) return;
|
||||
const { message: msg } = result;
|
||||
const { message: msg, index: idx } = result;
|
||||
try {
|
||||
const allMessages = await conversationsStore.getConversationMessages(activeConv.id);
|
||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||
const isFirstUserMessage =
|
||||
msg.role === MessageRole.USER && rootMessage && msg.parent === rootMessage.id;
|
||||
const parentId = msg.parent || rootMessage?.id;
|
||||
if (!parentId) return;
|
||||
const extrasToUse =
|
||||
newExtras !== undefined
|
||||
? JSON.parse(JSON.stringify(newExtras))
|
||||
: msg.extra
|
||||
? JSON.parse(JSON.stringify(msg.extra))
|
||||
: undefined;
|
||||
const newMessage = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId: msg.convId,
|
||||
type: msg.type,
|
||||
timestamp: Date.now(),
|
||||
role: msg.role,
|
||||
|
||||
let messageIdForResponse: string;
|
||||
|
||||
const dbMsg = findMessageById(allMessages, msg.id);
|
||||
const hasChildren = dbMsg ? dbMsg.children.length > 0 : msg.children.length > 0;
|
||||
|
||||
if (!hasChildren) {
|
||||
// No responses after this message — update in place instead of branching
|
||||
const updates: Partial<DatabaseMessage> = {
|
||||
content: newContent,
|
||||
toolCalls: msg.toolCalls || '',
|
||||
children: [],
|
||||
extra: extrasToUse,
|
||||
model: msg.model
|
||||
},
|
||||
parentId
|
||||
);
|
||||
await conversationsStore.updateCurrentNode(newMessage.id);
|
||||
timestamp: Date.now(),
|
||||
extra: extrasToUse
|
||||
};
|
||||
await DatabaseService.updateMessage(msg.id, updates);
|
||||
conversationsStore.updateMessageAtIndex(idx, updates);
|
||||
messageIdForResponse = msg.id;
|
||||
} else {
|
||||
// Has children — create a new branch as sibling
|
||||
const parentId = msg.parent || rootMessage?.id;
|
||||
if (!parentId) return;
|
||||
const newMessage = await DatabaseService.createMessageBranch(
|
||||
{
|
||||
convId: msg.convId,
|
||||
type: msg.type,
|
||||
timestamp: Date.now(),
|
||||
role: msg.role,
|
||||
content: newContent,
|
||||
toolCalls: msg.toolCalls || '',
|
||||
children: [],
|
||||
extra: extrasToUse,
|
||||
model: msg.model
|
||||
},
|
||||
parentId
|
||||
);
|
||||
await conversationsStore.updateCurrentNode(newMessage.id);
|
||||
messageIdForResponse = newMessage.id;
|
||||
}
|
||||
|
||||
conversationsStore.updateConversationTimestamp();
|
||||
if (isFirstUserMessage && newContent.trim())
|
||||
await conversationsStore.updateConversationTitleWithConfirmation(
|
||||
@@ -1301,7 +1323,8 @@ class ChatStore {
|
||||
newContent.trim()
|
||||
);
|
||||
await conversationsStore.refreshActiveMessages();
|
||||
if (msg.role === MessageRole.USER) await this.generateResponseForMessage(newMessage.id);
|
||||
if (msg.role === MessageRole.USER)
|
||||
await this.generateResponseForMessage(messageIdForResponse);
|
||||
} catch (error) {
|
||||
console.error('Failed to edit message with branching:', error);
|
||||
}
|
||||
|
||||
@@ -39,6 +39,12 @@ import {
|
||||
MULTIPLE_UNDERSCORE_REGEX,
|
||||
MCP_DEFAULT_ENABLED_LOCALSTORAGE_KEY
|
||||
} from '$lib/constants';
|
||||
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
|
||||
|
||||
export interface ConversationTreeItem {
|
||||
conversation: DatabaseConversation;
|
||||
depth: number;
|
||||
}
|
||||
|
||||
class ConversationsStore {
|
||||
/**
|
||||
@@ -300,15 +306,45 @@ class ConversationsStore {
|
||||
* Deletes a conversation and all its messages
|
||||
* @param convId - The conversation ID to delete
|
||||
*/
|
||||
async deleteConversation(convId: string): Promise<void> {
|
||||
async deleteConversation(convId: string, options?: { deleteWithForks?: boolean }): Promise<void> {
|
||||
try {
|
||||
await DatabaseService.deleteConversation(convId);
|
||||
await DatabaseService.deleteConversation(convId, options);
|
||||
|
||||
this.conversations = this.conversations.filter((c) => c.id !== convId);
|
||||
if (options?.deleteWithForks) {
|
||||
// Collect all descendants recursively
|
||||
const idsToRemove = new SvelteSet([convId]);
|
||||
const queue = [convId];
|
||||
while (queue.length > 0) {
|
||||
const parentId = queue.pop()!;
|
||||
for (const c of this.conversations) {
|
||||
if (c.forkedFromConversationId === parentId && !idsToRemove.has(c.id)) {
|
||||
idsToRemove.add(c.id);
|
||||
queue.push(c.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
this.conversations = this.conversations.filter((c) => !idsToRemove.has(c.id));
|
||||
|
||||
if (this.activeConversation?.id === convId) {
|
||||
this.clearActiveConversation();
|
||||
await goto(`?new_chat=true#/`);
|
||||
if (this.activeConversation && idsToRemove.has(this.activeConversation.id)) {
|
||||
this.clearActiveConversation();
|
||||
await goto(`?new_chat=true#/`);
|
||||
}
|
||||
} else {
|
||||
// Reparent direct children to deleted conv's parent (or promote to top-level)
|
||||
const deletedConv = this.conversations.find((c) => c.id === convId);
|
||||
const newParent = deletedConv?.forkedFromConversationId;
|
||||
this.conversations = this.conversations
|
||||
.filter((c) => c.id !== convId)
|
||||
.map((c) =>
|
||||
c.forkedFromConversationId === convId
|
||||
? { ...c, forkedFromConversationId: newParent }
|
||||
: c
|
||||
);
|
||||
|
||||
if (this.activeConversation?.id === convId) {
|
||||
this.clearActiveConversation();
|
||||
await goto(`?new_chat=true#/`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to delete conversation:', error);
|
||||
@@ -658,6 +694,42 @@ class ConversationsStore {
|
||||
this.saveMcpDefaults();
|
||||
}
|
||||
|
||||
/**
|
||||
* Forks a conversation at a specific message, creating a new conversation
|
||||
* containing messages from root up to the target message, then navigates to it.
|
||||
*
|
||||
* @param messageId - The message ID to fork at
|
||||
* @param options - Fork options (name and whether to include attachments)
|
||||
* @returns The new conversation ID, or null if fork failed
|
||||
*/
|
||||
async forkConversation(
|
||||
messageId: string,
|
||||
options: { name: string; includeAttachments: boolean }
|
||||
): Promise<string | null> {
|
||||
if (!this.activeConversation) return null;
|
||||
|
||||
try {
|
||||
const newConv = await DatabaseService.forkConversation(
|
||||
this.activeConversation.id,
|
||||
messageId,
|
||||
options
|
||||
);
|
||||
|
||||
this.conversations = [newConv, ...this.conversations];
|
||||
|
||||
await goto(`#/chat/${newConv.id}`);
|
||||
|
||||
toast.success('Conversation forked');
|
||||
|
||||
return newConv.id;
|
||||
} catch (error) {
|
||||
console.error('Failed to fork conversation:', error);
|
||||
toast.error('Failed to fork conversation');
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
@@ -830,3 +902,53 @@ export const conversations = () => conversationsStore.conversations;
|
||||
export const activeConversation = () => conversationsStore.activeConversation;
|
||||
export const activeMessages = () => conversationsStore.activeMessages;
|
||||
export const isConversationsInitialized = () => conversationsStore.isInitialized;
|
||||
|
||||
/**
|
||||
* Builds a flat tree of conversations with depth levels for nested forks.
|
||||
* Accepts a pre-filtered list so search filtering stays in the component.
|
||||
*/
|
||||
export function buildConversationTree(convs: DatabaseConversation[]): ConversationTreeItem[] {
|
||||
const childrenByParent = new SvelteMap<string, DatabaseConversation[]>();
|
||||
const forkIds = new SvelteSet<string>();
|
||||
|
||||
for (const conv of convs) {
|
||||
if (conv.forkedFromConversationId) {
|
||||
forkIds.add(conv.id);
|
||||
|
||||
const siblings = childrenByParent.get(conv.forkedFromConversationId) || [];
|
||||
|
||||
siblings.push(conv);
|
||||
childrenByParent.set(conv.forkedFromConversationId, siblings);
|
||||
}
|
||||
}
|
||||
|
||||
const result: ConversationTreeItem[] = [];
|
||||
const visited = new SvelteSet<string>();
|
||||
|
||||
function walk(conv: DatabaseConversation, depth: number) {
|
||||
visited.add(conv.id);
|
||||
result.push({ conversation: conv, depth });
|
||||
|
||||
const children = childrenByParent.get(conv.id);
|
||||
if (children) {
|
||||
children.sort((a, b) => b.lastModified - a.lastModified);
|
||||
|
||||
for (const child of children) {
|
||||
walk(child, depth + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const roots = convs.filter((c) => !forkIds.has(c.id));
|
||||
for (const root of roots) {
|
||||
walk(root, 0);
|
||||
}
|
||||
|
||||
for (const conv of convs) {
|
||||
if (!visited.has(conv.id)) {
|
||||
walk(conv, 1);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ export interface DatabaseConversation {
|
||||
lastModified: number;
|
||||
name: string;
|
||||
mcpServerOverrides?: McpServerOverride[];
|
||||
forkedFromConversationId?: string;
|
||||
}
|
||||
|
||||
export interface DatabaseMessageExtraAudioFile {
|
||||
|
||||
@@ -17,6 +17,17 @@
|
||||
|
||||
import { MessageRole } from '$lib/enums';
|
||||
|
||||
/**
|
||||
* Finds a message by its ID in the given messages array.
|
||||
*/
|
||||
export function findMessageById(
|
||||
messages: readonly DatabaseMessage[],
|
||||
id: string | null | undefined
|
||||
): DatabaseMessage | undefined {
|
||||
if (!id) return undefined;
|
||||
return messages.find((m) => m.id === id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters messages to get the conversation path from root to a specific leaf node.
|
||||
* If the leafNodeId doesn't exist, returns the path with the latest timestamp.
|
||||
|
||||
@@ -19,6 +19,21 @@ export function buildProxiedUrl(targetUrl: string): URL {
|
||||
return proxyUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap original headers for proxying through the CORS proxy. This avoids issues with duplicated llama.cpp-specific and target headers when using the CORS proxy.
|
||||
* @param headers - The original headers to be proxied to target
|
||||
* @returns List of "wrapped" headers to be sent to the CORS proxy
|
||||
*/
|
||||
export function buildProxiedHeaders(headers: Record<string, string>): Record<string, string> {
|
||||
const proxiedHeaders: Record<string, string> = {};
|
||||
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
proxiedHeaders[`X-Proxy-Header-${key}`] = value;
|
||||
}
|
||||
|
||||
return proxiedHeaders;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a proxied URL string for use in fetch requests.
|
||||
* @param targetUrl - The original URL to proxy
|
||||
|
||||
@@ -22,6 +22,7 @@ export { default as autoResizeTextarea } from './autoresize-textarea';
|
||||
// Branching utilities
|
||||
export {
|
||||
filterByLeafNodeId,
|
||||
findMessageById,
|
||||
findLeafNode,
|
||||
findDescendantMessages,
|
||||
getMessageSiblings,
|
||||
@@ -38,7 +39,7 @@ export { highlightCode, detectIncompleteCodeBlock, type IncompleteCodeBlock } fr
|
||||
export { setConfigValue, getConfigValue, configToParameterRecord } from './config-helpers';
|
||||
|
||||
// CORS Proxy
|
||||
export { buildProxiedUrl, getProxiedUrlString } from './cors-proxy';
|
||||
export { buildProxiedUrl, getProxiedUrlString, buildProxiedHeaders } from './cors-proxy';
|
||||
|
||||
// Conversation utilities
|
||||
export { createMessageCountMap, getMessageCount } from './conversation-utils';
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
import { browser } from '$app/environment';
|
||||
import { page } from '$app/state';
|
||||
import { untrack } from 'svelte';
|
||||
import { ChatSidebar, DialogConversationTitleUpdate } from '$lib/components/app';
|
||||
import {
|
||||
ChatSidebar,
|
||||
DialogConversationTitleUpdate,
|
||||
DialogChatSettings
|
||||
} from '$lib/components/app';
|
||||
import { isLoading } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, activeMessages } from '$lib/stores/conversations.svelte';
|
||||
import * as Sidebar from '$lib/components/ui/sidebar/index.js';
|
||||
@@ -17,8 +21,10 @@
|
||||
import { modelsStore } from '$lib/stores/models.svelte';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { TOOLTIP_DELAY_DURATION } from '$lib/constants';
|
||||
import type { SettingsSectionTitle } from '$lib/constants';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
import { IsMobile } from '$lib/hooks/is-mobile.svelte';
|
||||
import { setChatSettingsDialogContext } from '$lib/contexts';
|
||||
|
||||
let { children } = $props();
|
||||
|
||||
@@ -42,6 +48,16 @@
|
||||
let titleUpdateNewTitle = $state('');
|
||||
let titleUpdateResolve: ((value: boolean) => void) | null = null;
|
||||
|
||||
let chatSettingsDialogOpen = $state(false);
|
||||
let chatSettingsDialogInitialSection = $state<SettingsSectionTitle | undefined>(undefined);
|
||||
|
||||
setChatSettingsDialogContext({
|
||||
open: (initialSection?: SettingsSectionTitle) => {
|
||||
chatSettingsDialogInitialSection = initialSection;
|
||||
chatSettingsDialogOpen = true;
|
||||
}
|
||||
});
|
||||
|
||||
// Global keyboard shortcuts
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
const isCtrlOrCmd = event.ctrlKey || event.metaKey;
|
||||
@@ -213,6 +229,12 @@
|
||||
|
||||
<Toaster richColors />
|
||||
|
||||
<DialogChatSettings
|
||||
open={chatSettingsDialogOpen}
|
||||
onOpenChange={(open) => (chatSettingsDialogOpen = open)}
|
||||
initialSection={chatSettingsDialogInitialSection}
|
||||
/>
|
||||
|
||||
<DialogConversationTitleUpdate
|
||||
bind:open={titleUpdateDialogOpen}
|
||||
currentTitle={titleUpdateCurrentTitle}
|
||||
|
||||
@@ -73,7 +73,7 @@
|
||||
conversationsStore.conversations = mockConversations;
|
||||
}, 0));
|
||||
|
||||
const searchTrigger = screen.getByText('Search conversations');
|
||||
const searchTrigger = screen.getByText('Search');
|
||||
userEvent.click(searchTrigger);
|
||||
}}
|
||||
>
|
||||
|
||||
Vendored
+104
-41
@@ -467,10 +467,6 @@ bool set_socket_opt_impl(socket_t sock, int level, int optname,
|
||||
optlen) == 0;
|
||||
}
|
||||
|
||||
bool set_socket_opt(socket_t sock, int level, int optname, int optval) {
|
||||
return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval));
|
||||
}
|
||||
|
||||
bool set_socket_opt_time(socket_t sock, int level, int optname,
|
||||
time_t sec, time_t usec) {
|
||||
#ifdef _WIN32
|
||||
@@ -2218,7 +2214,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port,
|
||||
#ifdef _WIN32
|
||||
// Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so
|
||||
// remove the option.
|
||||
detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0);
|
||||
set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0);
|
||||
#endif
|
||||
|
||||
bool dummy;
|
||||
@@ -4373,6 +4369,7 @@ make_multipart_content_provider(const UploadFormDataItems &items,
|
||||
struct MultipartState {
|
||||
std::vector<std::string> owned;
|
||||
std::vector<MultipartSegment> segs;
|
||||
std::vector<char> buf = std::vector<char>(CPPHTTPLIB_SEND_BUFSIZ);
|
||||
};
|
||||
auto state = std::make_shared<MultipartState>();
|
||||
state->owned = std::move(owned);
|
||||
@@ -4381,19 +4378,49 @@ make_multipart_content_provider(const UploadFormDataItems &items,
|
||||
state->segs = std::move(segs);
|
||||
|
||||
return [state](size_t offset, size_t length, DataSink &sink) -> bool {
|
||||
// Buffer multiple small segments into fewer, larger writes to avoid
|
||||
// excessive TCP packets when there are many form data items (#2410)
|
||||
auto &buf = state->buf;
|
||||
auto buf_size = buf.size();
|
||||
size_t buf_len = 0;
|
||||
size_t remaining = length;
|
||||
|
||||
// Find the first segment containing 'offset'
|
||||
size_t pos = 0;
|
||||
for (const auto &seg : state->segs) {
|
||||
// Loop invariant: pos <= offset (proven by advancing pos only when
|
||||
// offset - pos >= seg.size, i.e., the segment doesn't contain offset)
|
||||
if (seg.size > 0 && offset - pos < seg.size) {
|
||||
size_t seg_offset = offset - pos;
|
||||
size_t available = seg.size - seg_offset;
|
||||
size_t to_write = (std::min)(available, length);
|
||||
return sink.write(seg.data + seg_offset, to_write);
|
||||
}
|
||||
size_t seg_idx = 0;
|
||||
for (; seg_idx < state->segs.size(); seg_idx++) {
|
||||
const auto &seg = state->segs[seg_idx];
|
||||
if (seg.size > 0 && offset - pos < seg.size) { break; }
|
||||
pos += seg.size;
|
||||
}
|
||||
return true; // past end (shouldn't be reached when content_length is exact)
|
||||
|
||||
size_t seg_offset = (seg_idx < state->segs.size()) ? offset - pos : 0;
|
||||
|
||||
for (; seg_idx < state->segs.size() && remaining > 0; seg_idx++) {
|
||||
const auto &seg = state->segs[seg_idx];
|
||||
size_t available = seg.size - seg_offset;
|
||||
size_t to_copy = (std::min)(available, remaining);
|
||||
const char *src = seg.data + seg_offset;
|
||||
seg_offset = 0; // only the first segment has a non-zero offset
|
||||
|
||||
while (to_copy > 0) {
|
||||
size_t space = buf_size - buf_len;
|
||||
size_t chunk = (std::min)(to_copy, space);
|
||||
std::memcpy(buf.data() + buf_len, src, chunk);
|
||||
buf_len += chunk;
|
||||
src += chunk;
|
||||
to_copy -= chunk;
|
||||
remaining -= chunk;
|
||||
|
||||
if (buf_len == buf_size) {
|
||||
if (!sink.write(buf.data(), buf_len)) { return false; }
|
||||
buf_len = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (buf_len > 0) { return sink.write(buf.data(), buf_len); }
|
||||
return true;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -5264,13 +5291,18 @@ bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx,
|
||||
*/
|
||||
|
||||
void default_socket_options(socket_t sock) {
|
||||
detail::set_socket_opt(sock, SOL_SOCKET,
|
||||
set_socket_opt(sock, SOL_SOCKET,
|
||||
#ifdef SO_REUSEPORT
|
||||
SO_REUSEPORT,
|
||||
SO_REUSEPORT,
|
||||
#else
|
||||
SO_REUSEADDR,
|
||||
SO_REUSEADDR,
|
||||
#endif
|
||||
1);
|
||||
1);
|
||||
}
|
||||
|
||||
bool set_socket_opt(socket_t sock, int level, int optname, int optval) {
|
||||
return detail::set_socket_opt_impl(sock, level, optname, &optval,
|
||||
sizeof(optval));
|
||||
}
|
||||
|
||||
std::string get_bearer_token_auth(const Request &req) {
|
||||
@@ -7418,6 +7450,8 @@ bool Server::read_content_core(
|
||||
return false;
|
||||
}
|
||||
|
||||
req.body_consumed_ = true;
|
||||
|
||||
if (req.is_multipart_form_data()) {
|
||||
if (!multipart_form_data_parser.is_valid()) {
|
||||
res.status = StatusCode::BadRequest_400;
|
||||
@@ -7688,9 +7722,7 @@ bool Server::listen_internal() {
|
||||
detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO,
|
||||
write_timeout_sec_, write_timeout_usec_);
|
||||
|
||||
if (tcp_nodelay_) {
|
||||
detail::set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1);
|
||||
}
|
||||
if (tcp_nodelay_) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); }
|
||||
|
||||
if (!task_queue->enqueue(
|
||||
[this, sock]() { process_and_close_socket(sock); })) {
|
||||
@@ -8036,8 +8068,19 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
return write_response(strm, close_connection, req, res);
|
||||
}
|
||||
|
||||
// RFC 9112 §6.3: Reject requests with both a non-zero Content-Length and
|
||||
// any Transfer-Encoding to prevent request smuggling. Content-Length: 0 is
|
||||
// tolerated for compatibility with existing clients.
|
||||
if (req.get_header_value_u64("Content-Length") > 0 &&
|
||||
req.has_header("Transfer-Encoding")) {
|
||||
connection_closed = true;
|
||||
res.status = StatusCode::BadRequest_400;
|
||||
return write_response(strm, close_connection, req, res);
|
||||
}
|
||||
|
||||
// Check if the request URI doesn't exceed the limit
|
||||
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
|
||||
connection_closed = true;
|
||||
res.status = StatusCode::UriTooLong_414;
|
||||
output_error_log(Error::ExceedUriMaxLength, &req);
|
||||
return write_response(strm, close_connection, req, res);
|
||||
@@ -8066,6 +8109,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
if (req.has_header("Accept")) {
|
||||
const auto &accept_header = req.get_header_value("Accept");
|
||||
if (!detail::parse_accept_header(accept_header, req.accept_content_types)) {
|
||||
connection_closed = true;
|
||||
res.status = StatusCode::BadRequest_400;
|
||||
output_error_log(Error::HTTPParsing, &req);
|
||||
return write_response(strm, close_connection, req, res);
|
||||
@@ -8075,6 +8119,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
if (req.has_header("Range")) {
|
||||
const auto &range_header_value = req.get_header_value("Range");
|
||||
if (!detail::parse_range_header(range_header_value, req.ranges)) {
|
||||
connection_closed = true;
|
||||
res.status = StatusCode::RangeNotSatisfiable_416;
|
||||
output_error_log(Error::InvalidRangeHeader, &req);
|
||||
return write_response(strm, close_connection, req, res);
|
||||
@@ -8202,6 +8247,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
}
|
||||
}
|
||||
#endif
|
||||
auto ret = false;
|
||||
if (routed) {
|
||||
if (res.status == -1) {
|
||||
res.status = req.ranges.empty() ? StatusCode::OK_200
|
||||
@@ -8209,6 +8255,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
}
|
||||
|
||||
// Serve file content by using a content provider
|
||||
auto file_open_error = false;
|
||||
if (!res.file_content_path_.empty()) {
|
||||
const auto &path = res.file_content_path_;
|
||||
auto mm = std::make_shared<detail::mmap>(path.c_str());
|
||||
@@ -8218,37 +8265,53 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
res.content_provider_ = nullptr;
|
||||
res.status = StatusCode::NotFound_404;
|
||||
output_error_log(Error::OpenFile, &req);
|
||||
return write_response(strm, close_connection, req, res);
|
||||
}
|
||||
file_open_error = true;
|
||||
} else {
|
||||
auto content_type = res.file_content_content_type_;
|
||||
if (content_type.empty()) {
|
||||
content_type = detail::find_content_type(
|
||||
path, file_extension_and_mimetype_map_, default_file_mimetype_);
|
||||
}
|
||||
|
||||
auto content_type = res.file_content_content_type_;
|
||||
if (content_type.empty()) {
|
||||
content_type = detail::find_content_type(
|
||||
path, file_extension_and_mimetype_map_, default_file_mimetype_);
|
||||
res.set_content_provider(
|
||||
mm->size(), content_type,
|
||||
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
|
||||
sink.write(mm->data() + offset, length);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
res.set_content_provider(
|
||||
mm->size(), content_type,
|
||||
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
|
||||
sink.write(mm->data() + offset, length);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
if (detail::range_error(req, res)) {
|
||||
if (file_open_error) {
|
||||
ret = write_response(strm, close_connection, req, res);
|
||||
} else if (detail::range_error(req, res)) {
|
||||
res.body.clear();
|
||||
res.content_length_ = 0;
|
||||
res.content_provider_ = nullptr;
|
||||
res.status = StatusCode::RangeNotSatisfiable_416;
|
||||
return write_response(strm, close_connection, req, res);
|
||||
ret = write_response(strm, close_connection, req, res);
|
||||
} else {
|
||||
ret = write_response_with_content(strm, close_connection, req, res);
|
||||
}
|
||||
|
||||
return write_response_with_content(strm, close_connection, req, res);
|
||||
} else {
|
||||
if (res.status == -1) { res.status = StatusCode::NotFound_404; }
|
||||
|
||||
return write_response(strm, close_connection, req, res);
|
||||
ret = write_response(strm, close_connection, req, res);
|
||||
}
|
||||
|
||||
// Drain any unconsumed request body to prevent request smuggling on
|
||||
// keep-alive connections.
|
||||
if (!req.body_consumed_ && detail::expect_content(req)) {
|
||||
int drain_status = 200; // required by read_content signature
|
||||
if (!detail::read_content(
|
||||
strm, req, payload_max_length_, drain_status, nullptr,
|
||||
[](const char *, size_t, size_t, size_t) { return true; }, false)) {
|
||||
// Body exceeds payload limit or read error — close the connection
|
||||
// to prevent leftover bytes from being misinterpreted.
|
||||
connection_closed = true;
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool Server::is_valid() const { return true; }
|
||||
|
||||
Vendored
+12
-2
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.39.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002700"
|
||||
#define CPPHTTPLIB_VERSION "0.40.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002800"
|
||||
|
||||
#ifdef _WIN32
|
||||
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00
|
||||
@@ -1266,6 +1266,7 @@ struct Request {
|
||||
bool is_multipart_form_data() const;
|
||||
|
||||
// private members...
|
||||
bool body_consumed_ = false;
|
||||
size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT;
|
||||
size_t content_length_ = 0;
|
||||
ContentProvider content_provider_;
|
||||
@@ -1475,6 +1476,8 @@ using SocketOptions = std::function<void(socket_t sock)>;
|
||||
|
||||
void default_socket_options(socket_t sock);
|
||||
|
||||
bool set_socket_opt(socket_t sock, int level, int optname, int optval);
|
||||
|
||||
const char *status_message(int status);
|
||||
|
||||
std::string to_string(Error error);
|
||||
@@ -1564,6 +1567,13 @@ ssize_t write_headers(Stream &strm, const Headers &headers);
|
||||
bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec,
|
||||
time_t usec);
|
||||
|
||||
size_t get_multipart_content_length(const UploadFormDataItems &items,
|
||||
const std::string &boundary);
|
||||
|
||||
ContentProvider
|
||||
make_multipart_content_provider(const UploadFormDataItems &items,
|
||||
const std::string &boundary);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
class Server {
|
||||
|
||||
Reference in New Issue
Block a user