mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-02 02:27:41 +02:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f5d15e665 | |||
| c46758d28f | |||
| bf934f28db | |||
| 5c1a7b8355 | |||
| 59d840209a | |||
| ff934e29bc | |||
| ee051c1e4e | |||
| e6f6770515 | |||
| 48cda24c11 | |||
| 871f1a2d2f | |||
| 20197b6fe3 | |||
| ba38f3becc | |||
| 37f230dd7c | |||
| a308e584ca | |||
| d0fa2c9fbb | |||
| 9bcb4eff4d | |||
| 6861f6509a | |||
| 1743d98057 | |||
| 7ca0c9cca7 | |||
| 8c60b8a2be | |||
| 287b5b1eab | |||
| a73bbd5d92 | |||
| ded446b34c | |||
| f8d4abae86 | |||
| 3d5acab3e7 | |||
| 9900b29c3a | |||
| dc8d14c582 | |||
| 93dfbc1291 | |||
| 3cba8bba18 | |||
| 112c78159f | |||
| 0fac87b157 | |||
| 0a524f2404 |
@@ -4,7 +4,7 @@
|
||||
|
||||
# Define the CANN base image for easier version updates later
|
||||
ARG CHIP_TYPE=910b
|
||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-${CHIP_TYPE}-openeuler24.03-py3.11
|
||||
ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.5.0-${CHIP_TYPE}-openeuler24.03-py3.11
|
||||
|
||||
# ==============================================================================
|
||||
# BUILD STAGE
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
ARG TARGETARCH
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git cmake libssl-dev
|
||||
apt-get install -y gcc-14 g++-14 build-essential git cmake libssl-dev
|
||||
|
||||
ENV CC=gcc-14 CXX=g++-14
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -55,8 +57,9 @@ RUN apt-get update \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
&& pip install --upgrade pip setuptools wheel \
|
||||
&& pip install -r requirements.txt \
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ARG ASCEND_VERSION=8.1.RC1.alpha001-910b-openeuler22.03-py3.10
|
||||
ARG ASCEND_VERSION=8.5.0-910b-openeuler22.03-py3.10
|
||||
|
||||
FROM ascendai/cann:$ASCEND_VERSION AS build
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@
|
||||
effectiveStdenv ? if useCuda then cudaPackages.backendStdenv else stdenv,
|
||||
enableStatic ? effectiveStdenv.hostPlatform.isStatic,
|
||||
precompileMetalShaders ? false,
|
||||
useWebUi ? true,
|
||||
}:
|
||||
|
||||
let
|
||||
@@ -164,6 +165,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
|
||||
cmakeFlags =
|
||||
[
|
||||
(cmakeBool "LLAMA_BUILD_SERVER" true)
|
||||
(cmakeBool "LLAMA_BUILD_WEBUI" useWebUi)
|
||||
(cmakeBool "BUILD_SHARED_LIBS" (!enableStatic))
|
||||
(cmakeBool "CMAKE_SKIP_BUILD_RPATH" true)
|
||||
(cmakeBool "GGML_NATIVE" false)
|
||||
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
distribution: zulu
|
||||
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
uses: android-actions/setup-android@9fc6c4e9069bf8d3d10b2204b1fb8f6ef7065407 # v3
|
||||
with:
|
||||
log-accepted-android-sdk-licenses: false
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
- name: Set container image
|
||||
id: cann-image
|
||||
run: |
|
||||
image="ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc2-910b-openeuler24.03-py3.11' || '8.3.rc2-310p-openeuler24.03-py3.11' }}"
|
||||
image="ascendai/cann:${{ matrix.chip_type == '910b' && '8.5.0-910b-openeuler24.03-py3.11' || '8.5.0-310p-openeuler24.03-py3.11' }}"
|
||||
echo "image=${image}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Pull container image
|
||||
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
# save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Setup ${{ matrix.sys }}
|
||||
uses: msys2/setup-msys2@v2
|
||||
uses: msys2/setup-msys2@cafece8e6baf9247cf9b1bf95097b0b983cc558d # v2
|
||||
with:
|
||||
update: true
|
||||
msystem: ${{matrix.sys}}
|
||||
|
||||
@@ -36,18 +36,16 @@ jobs:
|
||||
matrix:
|
||||
config:
|
||||
# Multi-stage build
|
||||
# Note: the arm64 images are failing, which prevents the amd64 images from being built
|
||||
# https://github.com/ggml-org/llama.cpp/issues/11888
|
||||
#- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
|
||||
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
|
||||
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04-s390x" }
|
||||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "cuda cuda12", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "12.4.0", ubuntu_version: "22.04" }
|
||||
- { tag: "cuda13", dockerfile: ".devops/cuda-new.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04", cuda_version: "13.1.0", ubuntu_version: "24.04" }
|
||||
- { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04-s390x" }
|
||||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-24.04" }
|
||||
- { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-24.04" }
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v6
|
||||
@@ -56,15 +54,15 @@ jobs:
|
||||
|
||||
- name: Set up QEMU
|
||||
if: ${{ matrix.config.tag != 's390x' }}
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3
|
||||
with:
|
||||
image: tonistiigi/binfmt:qemu-v7.0.0-28
|
||||
image: tonistiigi/binfmt:qemu-v10.2.1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
@@ -127,7 +125,7 @@ jobs:
|
||||
|
||||
- name: Build and push Full Docker image (tagged + versioned)
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.full == true }}
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
@@ -152,7 +150,7 @@ jobs:
|
||||
|
||||
- name: Build and push Light Docker image (tagged + versioned)
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }}
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
@@ -177,7 +175,7 @@ jobs:
|
||||
|
||||
- name: Build and push Server Docker image (tagged + versioned)
|
||||
if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }}
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
runs-on: ubuntu-slim
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: editorconfig-checker/action-editorconfig-checker@v2
|
||||
- uses: editorconfig-checker/action-editorconfig-checker@840e866d93b8e032123c23bac69dece044d4d84c # v2.2.0
|
||||
with:
|
||||
version: v3.0.3
|
||||
- run: editorconfig-checker
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Build package
|
||||
run: cd gguf-py && poetry build
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1
|
||||
with:
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
packages-dir: gguf-py/dist
|
||||
|
||||
@@ -31,6 +31,6 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: flake8 Lint
|
||||
uses: py-actions/flake8@v2
|
||||
uses: py-actions/flake8@84ec6726560b6d5bd68f2a5bed83d62b52bb50ba # v2
|
||||
with:
|
||||
plugins: "flake8-no-print"
|
||||
|
||||
@@ -907,7 +907,7 @@ jobs:
|
||||
- name: Set container image
|
||||
id: cann-image
|
||||
run: |
|
||||
image="ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc2-910b-openeuler24.03-py3.11' || '8.3.rc2-310p-openeuler24.03-py3.11' }}"
|
||||
image="ascendai/cann:${{ matrix.chip_type == '910b' && '8.5.0-910b-openeuler24.03-py3.11' || '8.5.0-310p-openeuler24.03-py3.11' }}"
|
||||
echo "image=${image}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Pull container image
|
||||
|
||||
@@ -108,6 +108,7 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_WEBUI "llama: build the embedded Web UI for server" ON)
|
||||
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
|
||||
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
|
||||
|
||||
|
||||
+17
-1
@@ -1079,7 +1079,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.verbose_prompt = true;
|
||||
}
|
||||
));
|
||||
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
|
||||
add_opt(common_arg(
|
||||
{"--display-prompt"},
|
||||
{"--no-display-prompt"},
|
||||
@@ -2807,6 +2807,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.port = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT"));
|
||||
add_opt(common_arg(
|
||||
{"--reuse-port"},
|
||||
string_format("allow multiple sockets to bind to the same port (default: %s)", params.reuse_port ? "enabled" : "disabled"),
|
||||
[](common_params & params) {
|
||||
params.reuse_port = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_REUSE_PORT"));
|
||||
add_opt(common_arg(
|
||||
{"--path"}, "PATH",
|
||||
string_format("path to serve static files from (default: %s)", params.public_path.c_str()),
|
||||
@@ -2843,6 +2850,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.webui_mcp_proxy = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
|
||||
add_opt(common_arg(
|
||||
{"--tools"}, "TOOL1,TOOL2,...",
|
||||
"experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
|
||||
"specify \"all\" to enable all tools\n"
|
||||
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
|
||||
add_opt(common_arg(
|
||||
{"--webui"},
|
||||
{"--no-webui"},
|
||||
|
||||
@@ -287,7 +287,7 @@ void analyze_reasoning::compare_reasoning_presence() {
|
||||
return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest());
|
||||
});
|
||||
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
});
|
||||
// try the more aggressive parse first, if it fails, fall back to the delimiter one
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
@@ -297,7 +297,7 @@ void analyze_reasoning::compare_reasoning_presence() {
|
||||
if (result.result.success()) {
|
||||
if (!result.tags["pre"].empty() && !result.tags["post"].empty()) {
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
start = trim_whitespace(result.tags["pre"]);
|
||||
start = trim_leading_whitespace(result.tags["pre"]);
|
||||
end = trim_trailing_whitespace(result.tags["post"]);
|
||||
} else if (!result.tags["post"].empty()) {
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
@@ -333,7 +333,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (left_trimmed.empty() && !diff.right.empty()) {
|
||||
if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) {
|
||||
if (start.empty()) {
|
||||
start = right_trimmed;
|
||||
start = trim_leading_whitespace(diff.right);
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -344,7 +344,7 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) {
|
||||
start = seg[seg.size() - 2].value;
|
||||
}
|
||||
end = left_trimmed;
|
||||
end = trim_trailing_whitespace(diff.left);
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
}
|
||||
}
|
||||
@@ -363,15 +363,23 @@ void analyze_reasoning::compare_thinking_enabled() {
|
||||
size_t len = std::min(base.size(), anchor_len);
|
||||
std::string anchor = base.substr(base.size() - len);
|
||||
auto pos = extended.rfind(anchor);
|
||||
if (pos == std::string::npos || pos + len >= extended.size()) continue;
|
||||
if (pos == std::string::npos || pos + len >= extended.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string extra = trim_whitespace(extended.substr(pos + len));
|
||||
if (extra.empty()) continue;
|
||||
if (extra.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto seg = prune_whitespace_segments(segmentize_markers(extra));
|
||||
if (seg.size() == 2 && seg[0].type == segment_type::MARKER && seg[1].type == segment_type::MARKER) {
|
||||
if (start.empty()) start = seg[0].value;
|
||||
if (end.empty()) end = seg[1].value;
|
||||
if (start.empty()) {
|
||||
start = seg[0].value;
|
||||
}
|
||||
if (end.empty()) {
|
||||
end = seg[1].value;
|
||||
}
|
||||
mode = reasoning_mode::TAG_BASED;
|
||||
break;
|
||||
}
|
||||
@@ -423,7 +431,7 @@ void analyze_reasoning::compare_reasoning_scope() {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Detected TOOLS_ONLY reasoning mode\n" ANSI_RESET, __func__);
|
||||
|
||||
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space()));
|
||||
return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space()));
|
||||
});
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B);
|
||||
if (result.result.success()) {
|
||||
@@ -516,7 +524,7 @@ analyze_content::analyze_content(const common_chat_template & tmpl, const analyz
|
||||
// Take the more promising diff
|
||||
std::string pure_content = rdiff.length() > diff_tools.left.length() ? rdiff : diff_tools.left;
|
||||
auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) {
|
||||
return p.tag("pre", p.marker()) + p.space() + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
return p.tag("pre", p.marker() + p.space()) + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest();
|
||||
});
|
||||
auto result = parser_wrapped.parse_anywhere_and_extract(pure_content);
|
||||
start = result.tags["pre"];
|
||||
|
||||
@@ -656,6 +656,38 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||
return true;
|
||||
}
|
||||
|
||||
// simple glob:
//   ?    matches exactly one character other than '/'
//   *    matches any (possibly empty) run of characters other than '/'
//   **   matches any (possibly empty) run of characters, '/' included
//   **/  matches zero or more whole leading directories, i.e. the remainder of
//        the pattern may only start at the current position or right after a '/'
//        ("**/foo" matches "foo" and "a/b/foo" but not "barfoo")
static inline bool glob_match(const char * pattern, const char * str) {
    if (*pattern == '\0') {
        return *str == '\0';
    }
    if (pattern[0] == '*' && pattern[1] == '*') {
        const char * rest = pattern + 2;
        if (*rest == '/') {
            rest++;
            // zero directories consumed
            if (glob_match(rest, str)) {
                return true;
            }
            // one or more directories consumed: resume only at a segment boundary,
            // otherwise "**/foo" would also accept "barfoo"
            for (const char * s = str; *s != '\0'; s++) {
                if (*s == '/' && glob_match(rest, s + 1)) {
                    return true;
                }
            }
            return false;
        }
        // bare "**": try the remainder at every offset, including the end of str
        for (const char * s = str; ; s++) {
            if (glob_match(rest, s)) {
                return true;
            }
            if (*s == '\0') {
                return false;
            }
        }
    }
    if (*pattern == '*') {
        const char * rest = pattern + 1;
        // '*' never crosses a '/'
        for (; *str != '\0' && *str != '/'; str++) {
            if (glob_match(rest, str)) {
                return true;
            }
        }
        return glob_match(rest, str);
    }
    if (*pattern == '?' && *str != '\0' && *str != '/') {
        return glob_match(pattern + 1, str + 1);
    }
    // literal character; *pattern is not '\0' here, so this fails at the end of str
    if (*pattern == *str) {
        return glob_match(pattern + 1, str + 1);
    }
    return false;
}

// Returns true when the whole of `str` matches the glob `pattern` (see the rules above).
bool glob_match(const std::string & pattern, const std::string & str) {
    return glob_match(pattern.c_str(), str.c_str());
}
|
||||
|
||||
//
|
||||
// Filesystem utils
|
||||
//
|
||||
|
||||
@@ -573,6 +573,7 @@ struct common_params {
|
||||
|
||||
// server params
|
||||
int32_t port = 8080; // server listens on this network port
|
||||
bool reuse_port = false; // allow multiple sockets to bind to the same port
|
||||
int32_t timeout_read = 600; // http read timeout in seconds
|
||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
@@ -613,6 +614,9 @@ struct common_params {
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
// enable built-in tools
|
||||
std::vector<std::string> server_tools;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
std::string models_preset = ""; // directory containing model presets for the router server
|
||||
@@ -790,6 +794,8 @@ std::string string_from(const std::vector<int> & values);
|
||||
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
|
||||
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
|
||||
|
||||
bool glob_match(const std::string & pattern, const std::string & str);
|
||||
|
||||
//
|
||||
// Filesystem utils
|
||||
//
|
||||
|
||||
+16
-4
@@ -548,6 +548,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
|
||||
return best;
|
||||
}
|
||||
|
||||
static bool gguf_filename_is_model(const std::string & filepath) {
|
||||
if (!string_ends_with(filepath, ".gguf")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string filename = filepath;
|
||||
if (auto pos = filename.rfind('/'); pos != std::string::npos) {
|
||||
filename = filename.substr(pos + 1);
|
||||
}
|
||||
|
||||
return filename.find("mmproj") == std::string::npos &&
|
||||
filename.find("imatrix") == std::string::npos;
|
||||
}
|
||||
|
||||
static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
const std::string & tag) {
|
||||
std::vector<std::string> tags;
|
||||
@@ -561,8 +575,7 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
for (const auto & t : tags) {
|
||||
std::regex pattern(t + "[.-]", std::regex::icase);
|
||||
for (const auto & f : files) {
|
||||
if (string_ends_with(f.path, ".gguf") &&
|
||||
f.path.find("mmproj") == std::string::npos &&
|
||||
if (gguf_filename_is_model(f.path) &&
|
||||
std::regex_search(f.path, pattern)) {
|
||||
return f;
|
||||
}
|
||||
@@ -570,8 +583,7 @@ static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
}
|
||||
|
||||
for (const auto & f : files) {
|
||||
if (string_ends_with(f.path, ".gguf") &&
|
||||
f.path.find("mmproj") == std::string::npos) {
|
||||
if (gguf_filename_is_model(f.path)) {
|
||||
return f;
|
||||
}
|
||||
}
|
||||
|
||||
+172
-35
@@ -26,6 +26,8 @@ namespace nl = nlohmann;
|
||||
#include <windows.h>
|
||||
#else
|
||||
#define HOME_DIR "HOME"
|
||||
#include <unistd.h>
|
||||
#include <pwd.h>
|
||||
#endif
|
||||
|
||||
namespace hf_cache {
|
||||
@@ -38,6 +40,7 @@ static fs::path get_cache_directory() {
|
||||
const char * var;
|
||||
fs::path path;
|
||||
} entries[] = {
|
||||
{"LLAMA_CACHE", fs::path()},
|
||||
{"HF_HUB_CACHE", fs::path()},
|
||||
{"HUGGINGFACE_HUB_CACHE", fs::path()},
|
||||
{"HF_HOME", fs::path("hub")},
|
||||
@@ -50,6 +53,13 @@ static fs::path get_cache_directory() {
|
||||
return entry.path.empty() ? base : base / entry.path;
|
||||
}
|
||||
}
|
||||
#ifndef _WIN32
|
||||
const struct passwd * pw = getpwuid(getuid());
|
||||
|
||||
if (pw->pw_dir && *pw->pw_dir) {
|
||||
return fs::path(pw->pw_dir) / ".cache" / "huggingface" / "hub";
|
||||
}
|
||||
#endif
|
||||
throw std::runtime_error("Failed to determine HF cache directory");
|
||||
}();
|
||||
|
||||
@@ -325,9 +335,15 @@ hf_files get_repo_files(const std::string & repo_id,
|
||||
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
|
||||
file.oid = item["lfs"]["oid"].get<std::string>();
|
||||
}
|
||||
if (item["lfs"].contains("size") && item["lfs"]["size"].is_number()) {
|
||||
file.size = item["lfs"]["size"].get<size_t>();
|
||||
}
|
||||
} else if (item.contains("oid") && item["oid"].is_string()) {
|
||||
file.oid = item["oid"].get<std::string>();
|
||||
}
|
||||
if (file.size == 0 && item.contains("size") && item["size"].is_number()) {
|
||||
file.size = item["size"].get<size_t>();
|
||||
}
|
||||
|
||||
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
|
||||
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
|
||||
@@ -487,6 +503,34 @@ std::string finalize_file(const hf_file & file) {
|
||||
|
||||
// delete everything after this line, one day
|
||||
|
||||
// copied from download.cpp without the tag part
|
||||
struct gguf_split_info {
|
||||
std::string prefix; // tag included
|
||||
int index;
|
||||
int count;
|
||||
};
|
||||
|
||||
static gguf_split_info get_gguf_split_info(const std::string & path) {
|
||||
static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
|
||||
std::smatch m;
|
||||
|
||||
std::string prefix = path;
|
||||
if (!string_remove_suffix(prefix, ".gguf")) {
|
||||
return {};
|
||||
}
|
||||
|
||||
int index = 1;
|
||||
int count = 1;
|
||||
|
||||
if (std::regex_match(prefix, m, re_split)) {
|
||||
index = std::stoi(m[2].str());
|
||||
count = std::stoi(m[3].str());
|
||||
prefix = m[1].str();
|
||||
}
|
||||
|
||||
return {std::move(prefix), index, count};
|
||||
}
|
||||
|
||||
static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) {
|
||||
static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
|
||||
std::smatch match;
|
||||
@@ -504,25 +548,30 @@ static std::string make_old_cache_filename(const std::string & owner,
|
||||
return result;
|
||||
}
|
||||
|
||||
static void migrate_single_file(const fs::path & old_cache,
|
||||
const std::string & owner,
|
||||
const std::string & repo,
|
||||
const nl::json & node,
|
||||
const hf_files & files) {
|
||||
struct migrate_file {
|
||||
std::string path;
|
||||
std::string sha256;
|
||||
size_t size;
|
||||
fs::path old_path;
|
||||
fs::path etag_path;
|
||||
const hf_file * file;
|
||||
};
|
||||
|
||||
if (!node.contains("rfilename") ||
|
||||
!node.contains("lfs") ||
|
||||
!node["lfs"].contains("sha256")) {
|
||||
return;
|
||||
}
|
||||
using migrate_files = std::vector<migrate_file>;
|
||||
|
||||
std::string path = node["rfilename"];
|
||||
std::string sha256 = node["lfs"]["sha256"];
|
||||
static bool collect_file(const fs::path & old_cache,
|
||||
const std::string & owner,
|
||||
const std::string & repo,
|
||||
const std::string & path,
|
||||
const std::string & sha256,
|
||||
const hf_files & files,
|
||||
migrate_files & to_migrate) {
|
||||
|
||||
const hf_file * file = nullptr;
|
||||
|
||||
const hf_file * file_info = nullptr;
|
||||
for (const auto & f : files) {
|
||||
if (f.path == path) {
|
||||
file_info = &f;
|
||||
file = &f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -532,41 +581,105 @@ static void migrate_single_file(const fs::path & old_cache,
|
||||
fs::path etag_path = old_path.string() + ".etag";
|
||||
|
||||
if (!fs::exists(old_path)) {
|
||||
if (fs::exists(etag_path)) {
|
||||
LOG_WRN("%s: %s is orphan, deleting...\n", __func__, etag_path.string().c_str());
|
||||
fs::remove(etag_path);
|
||||
if (file && fs::exists(file->final_path)) {
|
||||
return true;
|
||||
}
|
||||
return;
|
||||
LOG_WRN("%s: %s not found in old cache or HF cache\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!file_info) {
|
||||
LOG_WRN("%s: %s not found in current repo, ignoring...\n", __func__, old_filename.c_str());
|
||||
return;
|
||||
} else if (!sha256.empty() && !file_info->oid.empty() && sha256 != file_info->oid) {
|
||||
LOG_WRN("%s: %s is not up to date (sha256 mismatch), ignoring...\n", __func__, old_filename.c_str());
|
||||
return;
|
||||
if (!file) {
|
||||
LOG_WRN("%s: %s not found in current repo\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!sha256.empty() && !file->oid.empty() && sha256 != file->oid) {
|
||||
LOG_WRN("%s: %s is not up to date (sha256 mismatch)\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (file->size > 0) {
|
||||
size_t size = fs::file_size(old_path);
|
||||
if (size != file->size) {
|
||||
LOG_WRN("%s: %s has wrong size %zu (expected %zu)\n", __func__, old_filename.c_str(), size, file->size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
to_migrate.push_back({path, sha256, file->size, old_path, etag_path, file});
|
||||
return true;
|
||||
}
|
||||
|
||||
// Queue every old-cache file referenced by one manifest entry (`node`) into `to_migrate`.
//
// An entry without "rfilename" or "lfs.sha256" has nothing to migrate and counts as success.
// A non-split file is handed straight to collect_file(). For a split GGUF, every file in
// `files` sharing the same prefix and split count is collected; the number of such files
// must equal the split count.
//
// Returns false as soon as the split set is incomplete or collect_file() rejects a file.
static bool collect_files(const fs::path    & old_cache,
                          const std::string & owner,
                          const std::string & repo,
                          const nl::json    & node,
                          const hf_files    & files,
                          migrate_files     & to_migrate) {
    const bool has_lfs_hash = node.contains("rfilename") &&
                              node.contains("lfs") &&
                              node["lfs"].contains("sha256");
    if (!has_lfs_hash) {
        return true;
    }

    const std::string path   = node["rfilename"];
    const std::string sha256 = node["lfs"]["sha256"];

    const gguf_split_info wanted = get_gguf_split_info(path);
    if (wanted.count <= 1) {
        return collect_file(old_cache, owner, repo, path, sha256, files, to_migrate);
    }

    // (repo path, expected sha256) for each part of the split model
    std::vector<std::pair<std::string, std::string>> parts;

    for (const auto & candidate : files) {
        const gguf_split_info info = get_gguf_split_info(candidate.path);
        if (info.count != wanted.count || info.prefix != wanted.prefix) {
            continue;
        }
        // sadly the manifest only provides the sha256 of the first file (index == 1)
        // the rest will be verified using the size...
        parts.emplace_back(candidate.path, info.index == 1 ? sha256 : std::string());
    }

    const int n_found = (int) parts.size();
    if (n_found != wanted.count) {
        LOG_WRN("%s: expected %d split files but found %d in repo\n", __func__, wanted.count, n_found);
        return false;
    }

    for (const auto & part : parts) {
        if (!collect_file(old_cache, owner, repo, part.first, part.second, files, to_migrate)) {
            return false;
        }
    }

    return true;
}
|
||||
|
||||
static bool migrate_file(const migrate_file & file) {
|
||||
std::error_code ec;
|
||||
|
||||
fs::path new_path(file_info->local_path);
|
||||
fs::path new_path(file.file->local_path);
|
||||
fs::create_directories(new_path.parent_path(), ec);
|
||||
|
||||
if (!fs::exists(new_path, ec)) {
|
||||
fs::rename(old_path, new_path, ec);
|
||||
fs::rename(file.old_path, new_path, ec);
|
||||
if (ec) {
|
||||
fs::copy_file(old_path, new_path, ec);
|
||||
fs::copy_file(file.old_path, new_path, ec);
|
||||
if (ec) {
|
||||
LOG_WRN("%s: failed to move/copy %s: %s\n", __func__, old_path.string().c_str(), ec.message().c_str());
|
||||
return;
|
||||
LOG_ERR("%s: failed to move/copy %s: %s\n", __func__, file.old_path.string().c_str(), ec.message().c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
fs::remove(old_path, ec);
|
||||
fs::remove(file.old_path, ec);
|
||||
}
|
||||
fs::remove(etag_path, ec);
|
||||
fs::remove(file.etag_path, ec);
|
||||
|
||||
std::string filename = finalize_file(*file_info);
|
||||
LOG_INF("%s: migrated %s -> %s\n", __func__, old_filename.c_str(), filename.c_str());
|
||||
std::string filename = finalize_file(*file.file);
|
||||
LOG_INF("%s: migrated %s -> %s\n", __func__, file.old_path.filename().string().c_str(), filename.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
|
||||
@@ -614,19 +727,43 @@ void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
|
||||
continue;
|
||||
}
|
||||
|
||||
migrate_files to_migrate;
|
||||
bool ok = true;
|
||||
|
||||
try {
|
||||
std::ifstream manifest(entry.path());
|
||||
auto json = nl::json::parse(manifest);
|
||||
|
||||
for (const char * key : {"ggufFile", "mmprojFile"}) {
|
||||
if (json.contains(key)) {
|
||||
migrate_single_file(old_cache, owner, repo, json[key], files);
|
||||
if (!collect_files(old_cache, owner, repo, json[key], files, to_migrate)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LOG_WRN("%s: migration skipped: one or more files failed validation\n", __func__);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto & file : to_migrate) {
|
||||
if (!migrate_file(file)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LOG_WRN("%s: migration failed: could not migrate all files\n", __func__);
|
||||
continue;
|
||||
}
|
||||
|
||||
LOG_INF("%s: migration complete, deleting manifest: %s\n", __func__, entry.path().string().c_str());
|
||||
fs::remove(entry.path());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ struct hf_file {
|
||||
std::string final_path;
|
||||
std::string oid;
|
||||
std::string repo_id;
|
||||
size_t size = 0; // only for the migration
|
||||
};
|
||||
|
||||
using hf_files = std::vector<hf_file>;
|
||||
|
||||
+12
-11
@@ -115,9 +115,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
break;
|
||||
}
|
||||
case REASONING_BUDGET_FORCING:
|
||||
// force_pos is advanced in apply(), not here.
|
||||
// This ensures the first forced token isn't skipped when the sampler
|
||||
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
break;
|
||||
@@ -144,14 +146,6 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// advance to next forced token (done here rather than in accept so that
|
||||
// the first forced token isn't skipped when starting in FORCING state)
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
}
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
||||
@@ -261,3 +255,10 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
common_reasoning_budget_state initial_state) {
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
}
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
|
||||
if (!smpl) {
|
||||
return REASONING_BUDGET_IDLE;
|
||||
}
|
||||
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
|
||||
}
|
||||
|
||||
@@ -51,3 +51,5 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state);
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
|
||||
|
||||
+46
-10
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <unordered_map>
|
||||
@@ -109,6 +110,7 @@ struct common_sampler {
|
||||
common_params_sampling params;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
struct llama_sampler * rbudget;
|
||||
struct llama_sampler * chain;
|
||||
|
||||
ring_buffer<llama_token> prev;
|
||||
@@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
lparams.no_perf = params.no_perf;
|
||||
|
||||
llama_sampler * grmr = nullptr;
|
||||
llama_sampler * rbudget = nullptr;
|
||||
llama_sampler * chain = llama_sampler_chain_init(lparams);
|
||||
|
||||
std::vector<llama_sampler *> samplers;
|
||||
@@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
if (grmr) {
|
||||
if (grmr && !params.grammar_lazy) {
|
||||
try {
|
||||
for (const auto & token : prefill_tokens) {
|
||||
llama_sampler_accept(grmr, token);
|
||||
@@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
// reasoning budget sampler
|
||||
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
|
||||
rbudget = common_reasoning_budget_init(
|
||||
vocab,
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
prefill_tokens));
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
|
||||
prefill_tokens);
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
@@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .rbudget = */ rbudget,
|
||||
/* .chain = */ chain,
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
@@ -398,11 +402,27 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
}
|
||||
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
llama_sampler_free(gsmpl->rbudget);
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
|
||||
delete gsmpl;
|
||||
}
|
||||
|
||||
static bool grammar_should_apply(struct common_sampler * gsmpl) {
|
||||
if (!gsmpl->grmr) {
|
||||
return false;
|
||||
}
|
||||
if (!gsmpl->rbudget) {
|
||||
return true;
|
||||
}
|
||||
if (gsmpl->params.grammar_lazy) {
|
||||
// if grammar is lazy, only apply when reasoning budget is not active
|
||||
const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
|
||||
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
||||
if (!gsmpl) {
|
||||
return;
|
||||
@@ -410,6 +430,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
|
||||
|
||||
const auto tm = gsmpl->tm();
|
||||
|
||||
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
|
||||
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
|
||||
|
||||
llama_sampler_accept(gsmpl->rbudget, token);
|
||||
|
||||
if (gsmpl->grmr && accept_grammar) {
|
||||
llama_sampler_accept(gsmpl->grmr, token);
|
||||
}
|
||||
@@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||
return new common_sampler {
|
||||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .rbudget = */ llama_sampler_clone(gsmpl->rbudget),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
@@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
llama_token id = LLAMA_TOKEN_NULL;
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
auto & rbudget = gsmpl->rbudget;
|
||||
auto & chain = gsmpl->chain;
|
||||
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
||||
|
||||
@@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
|
||||
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
|
||||
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
|
||||
|
||||
// TODO: simplify
|
||||
gsmpl->cur.resize(1);
|
||||
@@ -524,7 +552,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
if (grammar_first) {
|
||||
// apply reasoning budget first
|
||||
llama_sampler_apply(rbudget, &cur_p);
|
||||
|
||||
if (grammar_first && grammar_should_apply(gsmpl)) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
@@ -532,7 +563,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
|
||||
id = cur_p.data[cur_p.selected].id;
|
||||
|
||||
if (grammar_first) {
|
||||
if (grammar_first || !grammar_should_apply(gsmpl)) {
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
llama_sampler_apply(rbudget, &cur_p);
|
||||
|
||||
if (grammar_should_apply(gsmpl)) {
|
||||
llama_sampler_apply(grmr, &cur_p);
|
||||
}
|
||||
|
||||
llama_sampler_apply(chain, &cur_p);
|
||||
|
||||
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
||||
|
||||
+175
-11
@@ -486,7 +486,7 @@ class ModelBase:
|
||||
elif quant_method == "modelopt":
|
||||
# Mixed-precision ModelOpt models: NVFP4 tensors are handled by
|
||||
# _generate_nvfp4_tensors; FP8 tensors have 1D weight_scale and
|
||||
# are dequantized here. input_scale tensors are unused.
|
||||
# are dequantized here. k/v scale tensors are unused.
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
@@ -494,7 +494,7 @@ class ModelBase:
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
|
||||
tensors_to_remove.append(name)
|
||||
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||
if name.endswith((".k_scale", ".v_scale")):
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method is not None:
|
||||
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
||||
@@ -542,7 +542,6 @@ class ModelBase:
|
||||
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
# Handle gate/up expert tensor fusion if enabled
|
||||
@@ -607,7 +606,12 @@ class ModelBase:
|
||||
def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool:
|
||||
return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6
|
||||
|
||||
def _repack_nvfp4(self, new_name: str, weight: Tensor, scale: Tensor, scale2: Tensor):
|
||||
def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "")
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
raw, shape = self._nvfp4_pack(weight, scale)
|
||||
logger.info(f"Repacked {new_name} with shape {shape} and quantization NVFP4")
|
||||
self.gguf_writer.add_tensor(new_name, raw, raw_dtype=gguf.GGMLQuantizationType.NVFP4)
|
||||
@@ -619,10 +623,18 @@ class ModelBase:
|
||||
logger.info(f" + {scale_name} (per-tensor NVFP4 scale2, shape [{scale2_f32.size}])")
|
||||
self.gguf_writer.add_tensor(scale_name, scale2_f32)
|
||||
|
||||
# Emit per-tensor input_scale as a separate F32 tensor when non-trivial
|
||||
if not self._nvfp4_scale2_is_trivial(input_scale):
|
||||
input_scale_f32 = input_scale.float().numpy().flatten()
|
||||
input_scale_name = new_name.replace(".weight", ".input_scale")
|
||||
logger.info(f" + {input_scale_name} (per-tensor NVFP4 input_scale, shape [{input_scale_f32.size}])")
|
||||
self.gguf_writer.add_tensor(input_scale_name, input_scale_f32)
|
||||
|
||||
def _generate_nvfp4_tensors(self):
|
||||
# Per-layer expert merging to avoid holding all experts in memory
|
||||
expert_blocks: dict[tuple[int, str], list[tuple[int, np.ndarray]]] = {}
|
||||
expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
|
||||
expert_input_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
|
||||
expert_shapes: dict[tuple[int, str], list[int]] = {}
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
|
||||
consumed: list[str] = []
|
||||
@@ -632,6 +644,7 @@ class ModelBase:
|
||||
continue
|
||||
scale_name = name.replace(".weight", ".weight_scale")
|
||||
scale2_name = name.replace(".weight", ".weight_scale_2")
|
||||
input_scale_name = name.replace(".weight", ".input_scale")
|
||||
if scale_name not in self.model_tensors:
|
||||
continue
|
||||
# Force eager materialization of lazy tensors
|
||||
@@ -643,11 +656,14 @@ class ModelBase:
|
||||
continue
|
||||
|
||||
scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())
|
||||
input_scale = LazyTorchTensor.to_eager(self.model_tensors.get(input_scale_name, lambda: torch.tensor(1.0))())
|
||||
|
||||
# Mark tensors for removal from model_tensors (already written to gguf)
|
||||
consumed.extend([name, scale_name])
|
||||
if scale2_name in self.model_tensors:
|
||||
consumed.append(scale2_name)
|
||||
if input_scale_name in self.model_tensors:
|
||||
consumed.append(input_scale_name)
|
||||
|
||||
# Check if this is a per-expert tensor
|
||||
m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
|
||||
@@ -663,34 +679,37 @@ class ModelBase:
|
||||
if key not in expert_blocks:
|
||||
expert_blocks[key] = []
|
||||
expert_scales[key] = []
|
||||
expert_input_scales[key] = []
|
||||
expert_shapes[key] = shape
|
||||
expert_blocks[key].append((expert_id, raw.copy()))
|
||||
# Collect per-expert scale2 (scalar per expert)
|
||||
expert_scales[key].append((expert_id, float(scale2.float().sum())))
|
||||
# Collect per-expert input_scale (scalar per expert)
|
||||
expert_input_scales[key].append((expert_id, float(input_scale.float().sum())))
|
||||
|
||||
# Flush when all experts for this (layer, proj) are collected
|
||||
if n_experts > 0 and len(expert_blocks[key]) >= n_experts:
|
||||
self._flush_nvfp4_experts(key, expert_blocks, expert_scales, expert_shapes, bid, proj_type)
|
||||
self._flush_nvfp4_experts(key, expert_blocks, expert_scales, expert_input_scales, expert_shapes, bid, proj_type)
|
||||
else:
|
||||
new_name = self.map_tensor_name(name)
|
||||
self._repack_nvfp4(new_name, weight, scale, scale2)
|
||||
self._repack_nvfp4(name, weight, scale, scale2, input_scale)
|
||||
|
||||
# Flush any remaining experts (fallback if n_experts was unknown)
|
||||
for (bid, proj_type) in list(expert_blocks.keys()):
|
||||
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)
|
||||
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_input_scales, expert_shapes, bid, proj_type)
|
||||
|
||||
# Remove consumed tensors so get_tensors/modify_tensors won't see them
|
||||
for name in consumed:
|
||||
self.model_tensors.pop(name, None)
|
||||
|
||||
# Remove unused auxiliary tensors (input_scale, k_scale, v_scale)
|
||||
# Remove any remaining unused auxiliary tensors
|
||||
for name in list(self.model_tensors.keys()):
|
||||
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||
if name.endswith((".k_scale", ".v_scale")):
|
||||
del self.model_tensors[name]
|
||||
|
||||
def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
|
||||
def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_input_scales, expert_shapes, bid, proj_type):
|
||||
experts = expert_blocks.pop(key)
|
||||
scales = expert_scales.pop(key)
|
||||
input_scales = expert_input_scales.pop(key)
|
||||
shape = expert_shapes.pop(key)
|
||||
|
||||
experts.sort(key=lambda x: x[0])
|
||||
@@ -708,6 +727,14 @@ class ModelBase:
|
||||
logger.info(f" + {scale_name} (per-expert NVFP4 scale2, shape [{len(scales)}])")
|
||||
self.gguf_writer.add_tensor(scale_name, scale_vals)
|
||||
|
||||
# Emit per-expert input_scale tensor if any expert has non-trivial input_scale
|
||||
input_scales.sort(key=lambda x: x[0])
|
||||
input_scale_vals = np.array([s[1] for s in input_scales], dtype=np.float32)
|
||||
if not np.allclose(input_scale_vals, 1.0, atol=1e-6):
|
||||
input_scale_name = new_name.replace(".weight", ".input_scale")
|
||||
logger.info(f" + {input_scale_name} (per-expert NVFP4 input_scale, shape [{len(input_scales)}])")
|
||||
self.gguf_writer.add_tensor(input_scale_name, input_scale_vals)
|
||||
|
||||
del experts, merged
|
||||
|
||||
def prepare_tensors(self):
|
||||
@@ -1311,6 +1338,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
|
||||
# ref: https://huggingface.co/aari1995/German_Semantic_V3
|
||||
res = "jina-v2-de"
|
||||
if chkhsh == "0fe1cf6eda062318a1af7270f3331a85c539a01778ff948e24388e949c5282f4":
|
||||
# ref: https://huggingface.co/evilfreelancer/ruGPT3XL
|
||||
res = "gpt-2"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
@@ -5011,6 +5041,97 @@ class _LinearAttentionVReorderBase(Qwen3NextModel):
|
||||
perm[dim], perm[dim + 1] = perm[dim + 1], perm[dim]
|
||||
return tensor.permute(*perm).contiguous().reshape(*shape)
|
||||
|
||||
def _transform_nvfp4_weight(self, name: str, weight: Tensor, scale: Tensor) -> tuple[Tensor, Tensor]:
|
||||
if not name.endswith((
|
||||
".linear_attn.in_proj_qkv.weight",
|
||||
".linear_attn.in_proj_z.weight",
|
||||
".linear_attn.in_proj_a.weight",
|
||||
".linear_attn.in_proj_b.weight",
|
||||
".linear_attn.out_proj.weight",
|
||||
)):
|
||||
return weight, scale
|
||||
|
||||
num_k_heads = self.hparams["linear_num_key_heads"]
|
||||
num_v_heads = self.hparams["linear_num_value_heads"]
|
||||
head_k_dim = self.hparams["linear_key_head_dim"]
|
||||
head_v_dim = self.hparams["linear_value_head_dim"]
|
||||
num_v_per_k = num_v_heads // num_k_heads
|
||||
|
||||
def unpack_nibbles(qs: Tensor) -> Tensor:
|
||||
lo = torch.bitwise_and(qs, 0x0F)
|
||||
hi = torch.bitwise_right_shift(qs, 4)
|
||||
return torch.stack((lo, hi), dim=-1).reshape(*qs.shape[:-1], qs.shape[-1] * 2)
|
||||
|
||||
def pack_nibbles(codes: Tensor) -> Tensor:
|
||||
codes = codes.reshape(*codes.shape[:-1], codes.shape[-1] // 2, 2)
|
||||
lo = torch.bitwise_and(codes[..., 0], 0x0F)
|
||||
hi = torch.bitwise_left_shift(torch.bitwise_and(codes[..., 1], 0x0F), 4)
|
||||
return torch.bitwise_or(lo, hi).contiguous()
|
||||
|
||||
def apply_col_perm(qs: Tensor, scales: Tensor, col_perm: Tensor) -> tuple[Tensor, Tensor]:
|
||||
assert qs.ndim >= 2
|
||||
assert scales.ndim >= 2
|
||||
|
||||
k = qs.shape[-1] * 2
|
||||
assert col_perm.numel() == k
|
||||
assert k % 16 == 0
|
||||
|
||||
group_cols = col_perm.reshape(-1, 16)
|
||||
group_starts = group_cols[:, 0]
|
||||
expected = group_starts.unsqueeze(1) + torch.arange(16, dtype=col_perm.dtype)
|
||||
assert torch.equal(group_cols, expected)
|
||||
assert torch.all(group_starts % 16 == 0)
|
||||
|
||||
group_perm = (group_starts // 16).to(dtype=torch.long)
|
||||
expected_groups = torch.arange(scales.shape[-1], dtype=torch.long)
|
||||
assert group_perm.numel() == scales.shape[-1]
|
||||
assert torch.equal(torch.sort(group_perm).values, expected_groups)
|
||||
|
||||
codes = unpack_nibbles(qs)
|
||||
codes = codes.index_select(-1, col_perm.to(device=qs.device, dtype=torch.long))
|
||||
qs = pack_nibbles(codes)
|
||||
scales = scales.index_select(-1, group_perm.to(device=scales.device))
|
||||
return qs, scales
|
||||
|
||||
def reorder_rows(qs: Tensor, scales: Tensor, head_dim: int) -> tuple[Tensor, Tensor]:
|
||||
row_perm = self._reorder_v_heads(
|
||||
torch.arange(num_v_heads * head_dim, dtype=torch.long).unsqueeze(-1),
|
||||
0, num_k_heads, num_v_per_k, head_dim,
|
||||
).squeeze(-1)
|
||||
return (
|
||||
qs.index_select(0, row_perm.to(device=qs.device)),
|
||||
scales.index_select(0, row_perm.to(device=scales.device)),
|
||||
)
|
||||
|
||||
if name.endswith(".linear_attn.in_proj_qkv.weight"):
|
||||
q_dim = head_k_dim * num_k_heads
|
||||
k_dim = head_k_dim * num_k_heads
|
||||
q = weight[:q_dim]
|
||||
k = weight[q_dim:q_dim + k_dim]
|
||||
v = weight[q_dim + k_dim:]
|
||||
q_scale = scale[:q_dim]
|
||||
k_scale = scale[q_dim:q_dim + k_dim]
|
||||
v_scale = scale[q_dim + k_dim:]
|
||||
v, v_scale = reorder_rows(v, v_scale, head_v_dim)
|
||||
return torch.cat([q, k, v], dim=0), torch.cat([q_scale, k_scale, v_scale], dim=0)
|
||||
|
||||
if name.endswith(".linear_attn.in_proj_z.weight"):
|
||||
weight, scale = reorder_rows(weight, scale, head_v_dim)
|
||||
elif name.endswith((".linear_attn.in_proj_a.weight", ".linear_attn.in_proj_b.weight")):
|
||||
weight, scale = reorder_rows(weight, scale, 1)
|
||||
elif name.endswith(".linear_attn.out_proj.weight"):
|
||||
col_perm = self._reorder_v_heads(
|
||||
torch.arange(num_v_heads * head_v_dim, dtype=torch.long).unsqueeze(0),
|
||||
1, num_k_heads, num_v_per_k, head_v_dim,
|
||||
).squeeze(0)
|
||||
weight, scale = apply_col_perm(weight, scale, col_perm)
|
||||
|
||||
return weight, scale
|
||||
|
||||
def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
|
||||
weight, scale = self._transform_nvfp4_weight(name, weight, scale)
|
||||
super()._repack_nvfp4(name, weight, scale, scale2, input_scale)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
num_k_heads = self.hparams.get("linear_num_key_heads", 0)
|
||||
num_v_heads = self.hparams.get("linear_num_value_heads", 0)
|
||||
@@ -5100,6 +5221,47 @@ class GPT2Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, new_name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("RuGPT3XLForCausalLM")
|
||||
class RuGPT3XLModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GPT2
|
||||
|
||||
_qkv_parts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Fuse separate Q, K, V projections into a single QKV tensor
|
||||
if ".self_attn.q_proj." in name or ".self_attn.k_proj." in name or ".self_attn.v_proj." in name:
|
||||
suffix = "weight" if name.endswith(".weight") else "bias"
|
||||
part = "q" if ".q_proj." in name else ("k" if ".k_proj." in name else "v")
|
||||
key = f"{part}.{suffix}"
|
||||
|
||||
assert bid is not None
|
||||
if self._qkv_parts is None:
|
||||
self._qkv_parts = [{} for _ in range(self.block_count)]
|
||||
self._qkv_parts[bid][key] = data_torch
|
||||
|
||||
q_key, k_key, v_key = f"q.{suffix}", f"k.{suffix}", f"v.{suffix}"
|
||||
if all(k in self._qkv_parts[bid] for k in [q_key, k_key, v_key]):
|
||||
q = self._qkv_parts[bid].pop(q_key)
|
||||
k = self._qkv_parts[bid].pop(k_key)
|
||||
v = self._qkv_parts[bid].pop(v_key)
|
||||
data_torch = torch.cat([q, k, v], dim=0)
|
||||
name = self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid, f".{suffix}")
|
||||
logger.debug(f"Fused Q/K/V {suffix} for layer {bid} -> {name}")
|
||||
else:
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._qkv_parts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
parts = [f"({i}){k}" for i, d in enumerate(self._qkv_parts) for k in d.keys()]
|
||||
if len(parts) > 0:
|
||||
raise ValueError(f"Unprocessed Q/K/V parts: {parts}")
|
||||
|
||||
|
||||
@ModelBase.register("PhiForCausalLM")
|
||||
class Phi2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
@@ -6988,6 +7150,8 @@ class DeepseekOCRVisionModel(MmprojModel):
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if ".rel_pos_h" in name or '.rel_pos_w' in name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if ".neck." in name or ".net_" in name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
|
||||
@@ -178,6 +178,7 @@ pre_computed_hashes = [
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/evilfreelancer/ruGPT3XL", "chkhsh": "0fe1cf6eda062318a1af7270f3331a85c539a01778ff948e24388e949c5282f4"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
+78
-54
@@ -42,12 +42,22 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi
|
||||
|
||||
### Ascend NPU
|
||||
|
||||
**Verified devices**
|
||||
You can retrieve your Ascend device IDs using the following command:
|
||||
|
||||
| Ascend NPU | Status |
|
||||
|:-----------------------------:|:-------:|
|
||||
| Atlas 300T A2 | Support |
|
||||
| Atlas 300I Duo | Support |
|
||||
```sh
|
||||
lspci -n | grep -Eo '19e5:d[0-9a-f]{3}' | cut -d: -f2
|
||||
```
|
||||
|
||||
**Devices**
|
||||
|
||||
| Device Id | Product Series | Product Models | Chip Model | Verified Status |
|
||||
|:---------:|----------------|----------------|:----------:|:---------------:|
|
||||
| d803 | Atlas A3 Train | | 910C | |
|
||||
| d803 | Atlas A3 Infer | | 910C | |
|
||||
| d802 | Atlas A2 Train | | 910B | |
|
||||
| d802 | Atlas A2 Infer | Atlas 300I A2 | 910B | Support |
|
||||
| d801 | Atlas Train | | 910 | |
|
||||
| d500 | Atlas Infer | Atlas 300I Duo | 310P | Support |
|
||||
|
||||
*Notes:*
|
||||
|
||||
@@ -57,6 +67,9 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi
|
||||
|
||||
## Model Supports
|
||||
|
||||
<details>
|
||||
<summary>Text-only</summary>
|
||||
|
||||
| Model Name | FP16 | Q4_0 | Q8_0 |
|
||||
|:----------------------------|:-----:|:----:|:----:|
|
||||
| Llama-2 | √ | √ | √ |
|
||||
@@ -118,8 +131,11 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi
|
||||
| Trillion-7B-preview | √ | √ | √ |
|
||||
| Ling models | √ | √ | √ |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Multimodal</summary>
|
||||
|
||||
**Multimodal**
|
||||
| Model Name | FP16 | Q4_0 | Q8_0 |
|
||||
|:----------------------------|:-----:|:----:|:----:|
|
||||
| LLaVA 1.5 models, LLaVA 1.6 models | x | x | x |
|
||||
@@ -134,15 +150,22 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi
|
||||
| GLM-EDGE | √ | √ | √ |
|
||||
| Qwen2-VL | √ | √ | √ |
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## DataType Supports
|
||||
|
||||
| DataType | Status |
|
||||
|:----------------------:|:-------:|
|
||||
| FP16 | Support |
|
||||
| Q8_0 | Support |
|
||||
| Q4_0 | Support |
|
||||
| DataType | 910B | 310P |
|
||||
|:----------------------:|:-------:|:-------:|
|
||||
| FP16 | Support | Support |
|
||||
| Q8_0 | Support | Partial |
|
||||
| Q4_0 | Support | Partial |
|
||||
| BF16 | Support | |
|
||||
|
||||
> **310P note**
|
||||
> - `Q8_0`: data transform / buffer path is implemented, and `GET_ROWS` is supported, but quantized `MUL_MAT` / `MUL_MAT_ID` are not supported.
|
||||
> - `Q4_0`: data transform / buffer path is implemented, but quantized `MUL_MAT` / `MUL_MAT_ID` are not supported.
|
||||
|
||||
## Docker
|
||||
|
||||
@@ -160,7 +183,20 @@ npu-smi info
|
||||
|
||||
# Select the cards that you want to use, make sure these cards are not used by someone.
|
||||
# Following using cards of device0.
|
||||
docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager --device /dev/devmm_svm --device /dev/hisi_hdc -v /usr/local/dcmi:/usr/local/dcmi -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info -v /PATH_TO_YOUR_MODELS/:/app/models -it llama-cpp-cann -m /app/models/MODEL_PATH -ngl 32 -p "Building a website can be done in 10 simple steps:"
|
||||
docker run --name llamacpp \
|
||||
--device /dev/davinci0 \
|
||||
--device /dev/davinci_manager \
|
||||
--device /dev/devmm_svm \
|
||||
--device /dev/hisi_hdc \
|
||||
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
||||
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
||||
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
||||
-v /PATH_TO_YOUR_MODELS/:/app/models \
|
||||
-it llama-cpp-cann \
|
||||
-m /app/models/MODEL_PATH \
|
||||
-ngl 32 \
|
||||
-p "Building a website can be done in 10 simple steps:"
|
||||
```
|
||||
|
||||
*Notes:*
|
||||
@@ -171,69 +207,57 @@ docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager
|
||||
|
||||
### I. Setup Environment
|
||||
|
||||
1. **Install Ascend Driver and firmware**
|
||||
1. **Configure Ascend user and group**
|
||||
|
||||
```sh
|
||||
# create driver running user.
|
||||
sudo groupadd -g HwHiAiUser
|
||||
sudo groupadd HwHiAiUser
|
||||
sudo useradd -g HwHiAiUser -d /home/HwHiAiUser -m HwHiAiUser -s /bin/bash
|
||||
sudo usermod -aG HwHiAiUser $USER
|
||||
|
||||
# download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system
|
||||
# and install driver.
|
||||
sudo sh Ascend-hdk-910b-npu-driver_x.x.x_linux-{arch}.run --full --install-for-all
|
||||
```
|
||||
|
||||
Once installed, run `npu-smi info` to check whether the driver was installed successfully.
|
||||
2. **Install dependencies**
|
||||
|
||||
**Ubuntu/Debian:**
|
||||
```sh
|
||||
+-------------------------------------------------------------------------------------------+
|
||||
| npu-smi 24.1.rc2 Version: 24.1.rc2 |
|
||||
+----------------------+---------------+----------------------------------------------------+
|
||||
| NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page)|
|
||||
| Chip | Bus-Id | AICore(%) Memory-Usage(MB) HBM-Usage(MB) |
|
||||
+======================+===============+====================================================+
|
||||
| 2 xxx | OK | 64.4 51 15 / 15 |
|
||||
| 0 | 0000:01:00.0 | 0 1873 / 15077 0 / 32768 |
|
||||
+======================+===============+====================================================+
|
||||
| 5 xxx | OK | 64.0 52 15 / 15 |
|
||||
| 0 | 0000:81:00.0 | 0 1874 / 15077 0 / 32768 |
|
||||
+======================+===============+====================================================+
|
||||
| No running processes found in NPU 2 |
|
||||
+======================+===============+====================================================+
|
||||
| No running processes found in NPU 5 |
|
||||
+======================+===============+====================================================+
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y gcc python3 python3-pip linux-headers-$(uname -r)
|
||||
```
|
||||
|
||||
2. **Install Ascend Firmware**
|
||||
**RHEL/CentOS:**
|
||||
```sh
|
||||
# download driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system
|
||||
# and install driver.
|
||||
sudo sh Ascend-hdk-910b-npu-firmware_x.x.x.x.X.run --full
|
||||
sudo yum makecache
|
||||
sudo yum install -y gcc python3 python3-pip kernel-headers-$(uname -r) kernel-devel-$(uname -r)
|
||||
```
|
||||
If the following message appears, firmware is installed successfully.
|
||||
|
||||
3. **Install CANN (driver + toolkit)**
|
||||
|
||||
> The `Ascend-cann` package includes both the driver and toolkit.
|
||||
> `$ARCH` can be `x86_64` or `aarch64`, `$CHIP` can be `910b` or `310p`.
|
||||
|
||||
```sh
|
||||
Firmware package installed successfully!
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.5.T63/Ascend-cann_8.5.0_linux-$ARCH.run
|
||||
sudo bash ./Ascend-cann_8.5.0_linux-$ARCH.run --install
|
||||
|
||||
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.5.T63/Ascend-cann-$CHIP-ops_8.5.0_linux-$ARCH.run
|
||||
sudo bash ./Ascend-cann-$CHIP-ops_8.5.0_linux-$ARCH.run --install
|
||||
```
|
||||
|
||||
4. **Verify installation**
|
||||
|
||||
3. **Install CANN toolkit and kernels**
|
||||
|
||||
CANN toolkit and kernels can be obtained from the official [CANN Toolkit](https://www.hiascend.com/zh/developer/download/community/result?module=cann) page.
|
||||
|
||||
Please download the version that matches your system. The minimum required version is 8.0.RC2.alpha002, and the install commands are shown below.
|
||||
```sh
|
||||
pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions
|
||||
sh Ascend-cann-toolkit_8.0.RC2.alpha002_linux-aarch64.run --install
|
||||
sh Ascend-cann-kernels-910b_8.0.RC2.alpha002_linux.run --install
|
||||
npu-smi info
|
||||
```
|
||||
|
||||
Set Ascend Variables:
|
||||
If device information is displayed correctly, the driver is functioning properly.
|
||||
|
||||
```sh
|
||||
echo "source ~/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
# Set environment variables (adjust path if needed)
|
||||
source /usr/local/Ascend/cann/set_env.sh
|
||||
|
||||
python3 -c "import acl; print(acl.get_soc_name())"
|
||||
```
|
||||
|
||||
Upon a successful installation, CANN is enabled for the available ascend devices.
|
||||
If the command outputs the chip model, the installation was successful.
|
||||
|
||||
### II. Build llama.cpp
|
||||
|
||||
|
||||
@@ -460,6 +460,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
if(NOT GGML_CPU_ALL_VARIANTS)
|
||||
set(MARCH_STR "rv64gc")
|
||||
if (GGML_RVV)
|
||||
string(APPEND MARCH_STR "v")
|
||||
endif()
|
||||
|
||||
if (GGML_RV_ZFH)
|
||||
string(APPEND MARCH_STR "_zfh")
|
||||
endif()
|
||||
@@ -467,7 +471,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_XTHEADVECTOR)
|
||||
string(APPEND MARCH_STR "_xtheadvector")
|
||||
elseif (GGML_RVV)
|
||||
string(APPEND MARCH_STR "_v")
|
||||
if (GGML_RV_ZVFH)
|
||||
string(APPEND MARCH_STR "_zvfh")
|
||||
endif()
|
||||
@@ -475,12 +478,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
string(APPEND MARCH_STR "_zvfbfwma")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_RV_ZICBOP)
|
||||
string(APPEND MARCH_STR "_zicbop")
|
||||
endif()
|
||||
if (GGML_RV_ZIHINTPAUSE)
|
||||
string(APPEND MARCH_STR "_zihintpause")
|
||||
endif()
|
||||
|
||||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||
else()
|
||||
# Begin with the lowest baseline
|
||||
|
||||
@@ -2871,8 +2871,12 @@ struct ggml_cplan ggml_graph_plan(
|
||||
const int64_t ne11 = node->src[1]->ne[1]; // H
|
||||
const int64_t ne12 = node->src[1]->ne[2]; // Channels In
|
||||
|
||||
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
|
||||
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
|
||||
GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03;
|
||||
cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12;
|
||||
|
||||
} break;
|
||||
case GGML_OP_TOP_K:
|
||||
{
|
||||
|
||||
+50
-19
@@ -6923,16 +6923,15 @@ void ggml_compute_forward_conv_3d(
|
||||
ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_transpose_2d
|
||||
|
||||
void ggml_compute_forward_conv_transpose_2d(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
template <typename kernel_t>
|
||||
static void ggml_compute_forward_conv_transpose_2d_impl(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
@@ -6943,7 +6942,7 @@ void ggml_compute_forward_conv_transpose_2d(
|
||||
|
||||
const int nk = ne00*ne01*ne02*ne03;
|
||||
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
if (ith == 0) {
|
||||
@@ -6951,12 +6950,12 @@ void ggml_compute_forward_conv_transpose_2d(
|
||||
|
||||
// permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
|
||||
{
|
||||
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
||||
kernel_t * const wdata = (kernel_t *) params->wdata + 0;
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
|
||||
ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
|
||||
const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02);
|
||||
kernel_t * dst_data = wdata + i02*ne01*ne00*ne03;
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
|
||||
@@ -6968,13 +6967,17 @@ void ggml_compute_forward_conv_transpose_2d(
|
||||
|
||||
// permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
|
||||
{
|
||||
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
|
||||
kernel_t * const wdata = (kernel_t *) params->wdata + nk;
|
||||
for (int i12 = 0; i12 < ne12; i12++) {
|
||||
for (int i11 = 0; i11 < ne11; i11++) {
|
||||
const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
|
||||
ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
|
||||
kernel_t * dst_data = wdata + i11*ne10*ne12;
|
||||
for (int i10 = 0; i10 < ne10; i10++) {
|
||||
dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
|
||||
if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
|
||||
dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
|
||||
} else {
|
||||
dst_data[i10*ne12 + i12] = src[i10];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6996,21 +6999,27 @@ void ggml_compute_forward_conv_transpose_2d(
|
||||
const int ip0 = dp*ith;
|
||||
const int ip1 = MIN(ip0 + dp, np);
|
||||
|
||||
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
||||
ggml_fp16_t * const wdata_src = wdata + nk;
|
||||
kernel_t * const wdata = (kernel_t *) params->wdata + 0;
|
||||
kernel_t * const wdata_src = wdata + nk;
|
||||
|
||||
for (int i2 = ip0; i2 < ip1; i2++) { // Cout
|
||||
float * dst_data = (float *)((char *) dst->data + i2*nb2);
|
||||
ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
|
||||
kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
|
||||
for (int i11 = 0; i11 < ne11; i11++) {
|
||||
for (int i10 = 0; i10 < ne10; i10++) {
|
||||
const int i1n = i11*ne10*ne12 + i10*ne12;
|
||||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
float v = 0;
|
||||
ggml_vec_dot_f16(ne03, &v, 0,
|
||||
wdata_src + i1n, 0,
|
||||
wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
|
||||
if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
|
||||
ggml_vec_dot_f16(ne03, &v, 0,
|
||||
wdata_src + i1n, 0,
|
||||
wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
|
||||
} else {
|
||||
ggml_vec_dot_f32(ne03, &v, 0,
|
||||
wdata_src + i1n, 0,
|
||||
wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
|
||||
}
|
||||
dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
|
||||
}
|
||||
}
|
||||
@@ -7019,6 +7028,28 @@ void ggml_compute_forward_conv_transpose_2d(
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_conv_transpose_2d(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_conv_2d_dw
|
||||
|
||||
struct ggml_conv_2d_dw_params {
|
||||
|
||||
@@ -799,6 +799,22 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
|
||||
#endif // CUDART_VERSION >= 12050
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
|
||||
#ifdef FP8_AVAILABLE
|
||||
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
|
||||
#if defined(GGML_USE_HIP) && defined(CDNA3)
|
||||
// ROCm does not support fp8 in software on devices with fp8 hardware,
|
||||
// but CDNA3 supports only e4m3_fnuz (no inf).
|
||||
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
|
||||
#else
|
||||
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA3)
|
||||
return static_cast<float>(xf) / 2;
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP8_AVAILABLE
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
|
||||
const uint8_t sign_bit = (x < 0.0f) << 3;
|
||||
float ax = fabsf(x) * e;
|
||||
@@ -931,6 +947,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
|
||||
static constexpr int qi = QI_MXFP4;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_NVFP4> {
|
||||
static constexpr int qk = QK_NVFP4;
|
||||
static constexpr int qr = QR_NVFP4;
|
||||
static constexpr int qi = QI_NVFP4;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
|
||||
static constexpr int qk = QK_K;
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "conv2d-transpose.cuh"
|
||||
#include "ggml.h"
|
||||
#include "convert.cuh"
|
||||
|
||||
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
|
||||
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
|
||||
const int out_h, const int kernel_w, const int kernel_h, const int stride,
|
||||
const int c_in, const int c_out, const int batches) {
|
||||
template <typename kernel_t>
|
||||
static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
|
||||
const kernel_t * __restrict__ kernel,
|
||||
float * __restrict__ output,
|
||||
const int in_w,
|
||||
const int in_h,
|
||||
const int out_w,
|
||||
const int out_h,
|
||||
const int kernel_w,
|
||||
const int kernel_h,
|
||||
const int stride,
|
||||
const int c_in,
|
||||
const int c_out,
|
||||
const int batches) {
|
||||
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
const int total_elements = out_w * out_h * c_out * batches;
|
||||
@@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const
|
||||
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
|
||||
for (int kh = 0; kh < kernel_h; ++kh) {
|
||||
int in_y = out_y_idx - kh;
|
||||
if (in_y < 0 || in_y % stride) continue;
|
||||
if (in_y < 0 || in_y % stride) {
|
||||
continue;
|
||||
}
|
||||
in_y /= stride;
|
||||
if (in_y >= in_h) continue;
|
||||
if (in_y >= in_h) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int kw = 0; kw < kernel_w; ++kw) {
|
||||
int in_x = out_x_idx - kw;
|
||||
if (in_x < 0 || in_x % stride) continue;
|
||||
if (in_x < 0 || in_x % stride) {
|
||||
continue;
|
||||
}
|
||||
in_x /= stride;
|
||||
if (in_x >= in_w) continue;
|
||||
if (in_x >= in_w) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
|
||||
const int kernel_idx =
|
||||
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
|
||||
|
||||
float input_val = input[input_idx];
|
||||
half kern_val = kernel[kernel_idx];
|
||||
float input_val = input[input_idx];
|
||||
kernel_t kern_val = kernel[kernel_idx];
|
||||
|
||||
accumulator += input_val * (float) kern_val;
|
||||
accumulator += input_val * ggml_cuda_cast<float>(kern_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
const ggml_tensor * kernel = dst->src[0];
|
||||
const ggml_tensor * input = dst->src[1];
|
||||
|
||||
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * input_data = (const float *) input->data;
|
||||
float * output_data = (float *) dst->data;
|
||||
const half * kernel_data = (const half *) kernel->data;
|
||||
const void * kernel_data = kernel->data;
|
||||
|
||||
const int input_w = input->ne[0];
|
||||
const int input_h = input->ne[1];
|
||||
@@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
GGML_ASSERT(ggml_is_contiguous(kernel));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
const int total = (output_w * output_h * channels_out * batches);
|
||||
const int total = output_w * output_h * channels_out * batches;
|
||||
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
|
||||
|
||||
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
|
||||
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
|
||||
channels_in, channels_out, batches);
|
||||
if (kernel->type == GGML_TYPE_F16) {
|
||||
conv2d_transpose_kernel<half><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
|
||||
input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
|
||||
kernel_h, stride, channels_in, channels_out, batches);
|
||||
|
||||
} else {
|
||||
conv2d_transpose_kernel<float><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
|
||||
input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
|
||||
kernel_h, stride, channels_in, channels_out, batches);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -617,6 +617,45 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
|
||||
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static __global__ void dequantize_block_nvfp4(
|
||||
const void * __restrict__ vx,
|
||||
dst_t * __restrict__ yy,
|
||||
const int64_t ne) {
|
||||
const int64_t i = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int64_t base = i * QK_NVFP4;
|
||||
if (base >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const block_nvfp4 * x = (const block_nvfp4 *) vx;
|
||||
const block_nvfp4 & xb = x[i];
|
||||
|
||||
const int sub = tid / (QK_NVFP4_SUB / 2);
|
||||
const int j = tid % (QK_NVFP4_SUB / 2);
|
||||
|
||||
const float d = ggml_cuda_ue4m3_to_fp32(xb.d[sub]);
|
||||
const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
|
||||
|
||||
const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
|
||||
const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
|
||||
|
||||
yy[y0] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
|
||||
yy[y1] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_nvfp4_cuda(
|
||||
const void * vx,
|
||||
dst_t * y,
|
||||
const int64_t k,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(k % QK_NVFP4 == 0);
|
||||
const int nb = k / QK_NVFP4;
|
||||
dequantize_block_nvfp4<<<nb, 32, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
template <typename src_t, typename dst_t>
|
||||
static __global__ void convert_unary(
|
||||
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
|
||||
@@ -715,6 +754,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return dequantize_row_mxfp4_cuda;
|
||||
case GGML_TYPE_NVFP4:
|
||||
return dequantize_row_nvfp4_cuda;
|
||||
case GGML_TYPE_F32:
|
||||
return convert_unary_cont_cuda<float>;
|
||||
case GGML_TYPE_BF16:
|
||||
@@ -766,6 +807,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
return dequantize_row_iq3_s_cuda;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return dequantize_row_mxfp4_cuda;
|
||||
case GGML_TYPE_NVFP4:
|
||||
return dequantize_row_nvfp4_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
return convert_unary_cont_cuda<half>;
|
||||
case GGML_TYPE_BF16:
|
||||
|
||||
@@ -1297,7 +1297,12 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
|
||||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
|
||||
|
||||
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
|
||||
const bool use_fp16 =
|
||||
src0->type != GGML_TYPE_NVFP4 &&
|
||||
(src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||
ggml_is_contiguous(src0) &&
|
||||
row_diff == src0->ne[1] &&
|
||||
dst->op_params[0] == GGML_PREC_DEFAULT;
|
||||
|
||||
if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
|
||||
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
|
||||
@@ -4781,6 +4786,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
#ifdef FP8_AVAILABLE
|
||||
case GGML_TYPE_NVFP4:
|
||||
#endif // FP8_AVAILABLE
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
|
||||
@@ -15,6 +15,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
|
||||
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
|
||||
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
|
||||
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
|
||||
case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1;
|
||||
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
|
||||
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
|
||||
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
|
||||
@@ -41,6 +42,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
|
||||
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
|
||||
case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
|
||||
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
|
||||
@@ -626,6 +628,12 @@ static void mul_mat_vec_q_switch_type(
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
||||
break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
|
||||
@@ -322,6 +322,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
|
||||
return d * sumi;
|
||||
}
|
||||
|
||||
#define VDR_NVFP4_Q8_1_MMVQ 4
|
||||
#define VDR_NVFP4_Q8_1_MMQ 8
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_nvfp4_q8_1(
|
||||
const void * __restrict__ vbq,
|
||||
const block_q8_1 * __restrict__ bq8_1,
|
||||
const int32_t & kbx,
|
||||
const int32_t & iqs) {
|
||||
|
||||
const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx;
|
||||
float sum = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
|
||||
const int32_t iqs0 = iqs + 2*i;
|
||||
const int32_t iqs1 = iqs0 + 1;
|
||||
const int32_t is = iqs0 >> 1;
|
||||
const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
|
||||
const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
|
||||
const block_q8_1 * bq8 = bq8_1 + (is >> 1);
|
||||
const int32_t i8 = ((is & 1) << 2);
|
||||
|
||||
int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0);
|
||||
sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi);
|
||||
sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi);
|
||||
sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi);
|
||||
|
||||
const float d = ggml_cuda_ue4m3_to_fp32(bq4->d[is]) * __low2float(bq8->ds);
|
||||
sum += d * float(sumi);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
#define VDR_Q2_K_Q8_1_MMVQ 1
|
||||
#define VDR_Q2_K_Q8_1_MMQ 4
|
||||
|
||||
|
||||
Vendored
+3
-2
@@ -6,9 +6,10 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if CUDART_VERSION >= 12050
|
||||
#if CUDART_VERSION >= 11080
|
||||
#include <cuda_fp8.h>
|
||||
#endif // CUDART_VERSION >= 12050
|
||||
#define FP8_AVAILABLE
|
||||
#endif // CUDART_VERSION >= 11080
|
||||
|
||||
#if CUDART_VERSION >= 12080
|
||||
#include <cuda_fp4.h>
|
||||
|
||||
Vendored
+6
@@ -235,6 +235,12 @@
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
#if HIP_VERSION >= 60200000
|
||||
#include <hip/hip_fp8.h>
|
||||
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
|
||||
#define FP8_AVAILABLE
|
||||
#endif // HIP_VERSION >= 60200000
|
||||
|
||||
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
||||
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
||||
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
||||
|
||||
@@ -1406,6 +1406,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
repack_q8_0_q8x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
// IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16])
|
||||
repack_q4_0_q4x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
@@ -1442,6 +1449,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
repack_q8x4x2_q8_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q4x4x2_q4_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
@@ -1819,6 +1832,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
if (src0->ne[0] % 32) {
|
||||
return false;
|
||||
@@ -1868,6 +1882,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
if ((src0->ne[0] % 32)) {
|
||||
return false;
|
||||
@@ -2596,8 +2611,26 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
|
||||
delete backend;
|
||||
}
|
||||
|
||||
// Map weight type to its activation quantization family.
|
||||
// Types in the same family produce identical Q8 formats in VTCM and can
|
||||
// safely share quantized activation data via SKIP_QUANTIZE.
|
||||
// When adding a new quantized type, assign it the correct family here.
|
||||
static inline int act_quant_family(enum ggml_type wtype) {
|
||||
switch (wtype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
return 1; // Q8x4x2
|
||||
default:
|
||||
return 0; // unknown / not quantized
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
|
||||
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
|
||||
return (op0 && op0->src[1] == op1->src[1] &&
|
||||
act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) &&
|
||||
act_quant_family(op0->src[0]->type) != 0);
|
||||
}
|
||||
|
||||
static inline bool is_compute_op(ggml_tensor *node)
|
||||
@@ -3364,6 +3397,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
|
||||
const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
|
||||
const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
|
||||
|
||||
@@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
|
||||
};
|
||||
|
||||
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
|
||||
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||||
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0,
|
||||
};
|
||||
|
||||
static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
-127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
|
||||
1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0,
|
||||
@@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned
|
||||
|
||||
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
|
||||
#define HMX_X4X2_SCALES_PER_BLK 8
|
||||
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes
|
||||
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL)
|
||||
#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4)
|
||||
|
||||
static inline void swap_ptr(void **p1, void **p2) {
|
||||
void *t = *p1;
|
||||
@@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||||
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||||
case HTP_TYPE_Q8_0:
|
||||
return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||||
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||||
case HTP_TYPE_MXFP4:
|
||||
return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
@@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
|
||||
}
|
||||
|
||||
// --- MXFP4 E8M0 scale conversion and dequantization ---
|
||||
//
|
||||
// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack.
|
||||
// Scalar loads from the stack array execute on the scalar pipeline, in parallel
|
||||
// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop.
|
||||
// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10
|
||||
// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15.
|
||||
|
||||
// Eight fp16 scales belonging to one x4x2 MXFP4 block.
// Filled by mxfp4_convert_scales() and read element-wise by mxfp4_extract_splat().
// The 16-byte alignment matches the 16 bytes written by the vector store that fills it.
typedef struct {
|
||||
__fp16 v[8] __attribute__((aligned(16)));
|
||||
} mxfp4_scales_t;
|
||||
|
||||
static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) {
|
||||
mxfp4_scales_t s;
|
||||
HVX_Vector v = hvx_vmemu(e8m0_8);
|
||||
HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v));
|
||||
vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112));
|
||||
vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero());
|
||||
vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30));
|
||||
vh = Q6_Vh_vasl_VhR(vh, 10);
|
||||
hvx_vec_store_u(s.v, 16, vh);
|
||||
return s;
|
||||
}
|
||||
|
||||
// Broadcast one of the 8 converted scales across a whole HVX vector.
//   scales : scales of the current block (see mxfp4_convert_scales)
//   idx    : which scale to pick; must index into scales.v[8]
static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) {
    const __fp16 picked = scales.v[idx];  // plain scalar read from the struct
    return hvx_vec_splat_f16(picked);
}
|
||||
|
||||
// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16.
|
||||
static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32,
|
||||
bool upper_nibbles,
|
||||
int sub_blk,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc));
|
||||
}
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
|
||||
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
|
||||
bool upper_nibbles,
|
||||
int sub_blk_base,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales,
|
||||
HVX_Vector out[4]) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
|
||||
HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
|
||||
HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0),
|
||||
mxfp4_extract_splat(scales, sub_blk_base + 1));
|
||||
HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2),
|
||||
mxfp4_extract_splat(scales, sub_blk_base + 3));
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
out[0] = v_lo;
|
||||
out[1] = Q6_V_vror_VR(v_lo, 64);
|
||||
out[2] = v_hi;
|
||||
out[3] = Q6_V_vror_VR(v_hi, 64);
|
||||
}
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
|
||||
// Output: vtcm_dst in tile-major FP16 layout.
|
||||
@@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const int qrow_size = is_q4 ? (k_block / 2) : k_block;
|
||||
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
|
||||
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL)
|
||||
? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut);
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
|
||||
hvx_vmem(q4_0_to_fp16_lut);
|
||||
|
||||
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
|
||||
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
|
||||
@@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
int ct = t / n_k_tiles; // column tile index
|
||||
int kt = t % n_k_tiles; // K tile index
|
||||
|
||||
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
|
||||
((t + 3) / n_k_tiles == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
@@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
|
||||
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
|
||||
|
||||
__fp16 * tile_bases[4];
|
||||
for (int g = 0; g < 4; g++) {
|
||||
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (result is an __fp16[8] struct, read back with scalar loads)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0[4], v1[4];
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
|
||||
} else {
|
||||
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
(void) *(volatile HVX_Vector *) (tile_bases[g]);
|
||||
}
|
||||
|
||||
t += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Single-tile fallback ---
|
||||
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
if (is_q4) {
|
||||
if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
@@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *)(tile_base);
|
||||
} else if (weight_type == HTP_TYPE_MXFP4) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (result is an __fp16[8] struct, read back with scalar loads)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
|
||||
HVX_Vector v1;
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
v1 = Q6_V_vzero();
|
||||
}
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *) (tile_base);
|
||||
} else {
|
||||
// Q8_0
|
||||
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
|
||||
@@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
{
|
||||
qweight_fetch_task_state_t s;
|
||||
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const int blk_start = kk / QK_Q4_0x4x2;
|
||||
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
|
||||
const int full_qrow = is_q4 ? (k / 2) : k;
|
||||
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
|
||||
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
|
||||
const int scale_blk_size =
|
||||
(weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
|
||||
|
||||
s.dst = vtcm_scratch0;
|
||||
s.src = w + nc * row_stride;
|
||||
s.n_rows = n_blk_sz;
|
||||
s.src_stride = row_stride;
|
||||
s.dst_stride = sub_row_stride;
|
||||
s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2);
|
||||
s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2);
|
||||
s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE;
|
||||
s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE;
|
||||
s.quant_off =
|
||||
(weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
|
||||
s.quant_width =
|
||||
(weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
|
||||
s.scale_off = full_qrow + blk_start * scale_blk_size;
|
||||
s.scale_width = nb_sub * scale_blk_size;
|
||||
|
||||
// 2D DMA: quants sub-range
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
|
||||
|
||||
@@ -31,6 +31,12 @@ struct htp_context {
|
||||
|
||||
uint32_t opmask;
|
||||
|
||||
// Cached src1 spad position from the last quantize pass.
|
||||
// When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM
|
||||
// at this address; the matmul must read from here instead of recomputing
|
||||
// the offset (which depends on the current op's src0 size).
|
||||
uint8_t * prev_src1_spad;
|
||||
|
||||
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
|
||||
#ifdef HTP_HAS_HMX
|
||||
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
|
||||
|
||||
@@ -1114,14 +1114,12 @@ static void proc_hmx_matmul_req(struct htp_context * ctx,
|
||||
return;
|
||||
}
|
||||
|
||||
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
|
||||
// Other types (e.g. MXFP4) fall back to HVX.
|
||||
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
|
||||
// Other types fall back to HVX.
|
||||
{
|
||||
uint32_t wtype = req->src0.type;
|
||||
if (wtype != HTP_TYPE_F16 &&
|
||||
wtype != HTP_TYPE_Q4_0 &&
|
||||
wtype != HTP_TYPE_Q8_0 &&
|
||||
wtype != HTP_TYPE_IQ4_NL) {
|
||||
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL &&
|
||||
wtype != HTP_TYPE_MXFP4) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -60,6 +60,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
|
||||
0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
|
||||
};
|
||||
|
||||
// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue
|
||||
// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
|
||||
static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = {
|
||||
0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0,
|
||||
0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
};
|
||||
|
||||
static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
||||
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
|
||||
0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
@@ -68,6 +78,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
};
|
||||
|
||||
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) {
|
||||
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
||||
|
||||
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
|
||||
HVX_Vector v2_3 = vptr[1]; // ...
|
||||
HVX_Vector v4_5 = vptr[2]; // ...
|
||||
HVX_Vector v6_7 = vptr[3]; // ...
|
||||
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
||||
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
|
||||
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
|
||||
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
|
||||
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
|
||||
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
|
||||
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
|
||||
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
|
||||
|
||||
v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
||||
v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
||||
v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
|
||||
v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
|
||||
v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
|
||||
v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
|
||||
v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
|
||||
v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
|
||||
|
||||
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
|
||||
return r;
|
||||
}
|
||||
|
||||
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
|
||||
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2; // 256
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
||||
|
||||
HVX_Vector_x8 r;
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nb; i++) {
|
||||
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
|
||||
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
||||
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
|
||||
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
|
||||
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
|
||||
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
|
||||
|
||||
static inline size_t q8x4x2_row_size(uint32_t ne) {
|
||||
@@ -921,6 +998,293 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
||||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
||||
}
|
||||
|
||||
// ======== IQ4_NL x Q8_0 vec_dot kernels ========
|
||||
// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue).
|
||||
// Scale format is identical to Q4_0 (fp16 scales).
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n,
|
||||
float * restrict s0,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vy0) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
||||
|
||||
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
||||
|
||||
HVX_Vector r0_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
||||
|
||||
hvx_vec_store_u(s0, 4, r0_sum);
|
||||
}
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n,
|
||||
float * restrict s0,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vx1,
|
||||
const void * restrict vy0) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vx1 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
||||
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
||||
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
||||
|
||||
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
||||
|
||||
HVX_Vector r0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
||||
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
||||
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
||||
hvx_vec_store_u(s0, 8, rsum);
|
||||
}
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n,
|
||||
float * restrict s0,
|
||||
float * restrict s1,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vx1,
|
||||
const void * restrict vy0,
|
||||
const void * restrict vy1) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vx1 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
assert((unsigned long) vy1 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
|
||||
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
|
||||
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
|
||||
|
||||
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;
|
||||
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;
|
||||
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;
|
||||
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;
|
||||
|
||||
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
||||
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
||||
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
||||
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
||||
|
||||
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
||||
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
||||
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
||||
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
||||
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
||||
|
||||
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
||||
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
||||
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
||||
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
||||
|
||||
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
||||
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
||||
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
||||
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
||||
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
||||
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
||||
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
||||
|
||||
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
||||
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
||||
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
||||
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
||||
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
||||
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
||||
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
||||
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
||||
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
||||
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
||||
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
||||
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
||||
|
||||
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
||||
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
||||
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
||||
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
||||
|
||||
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
||||
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
||||
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
||||
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
||||
}
|
||||
|
||||
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
||||
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
||||
|
||||
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);
|
||||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);
|
||||
}
|
||||
|
||||
static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
||||
assert(n % 32 == 0); // min sub-block size
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
@@ -2393,6 +2757,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t
|
||||
mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
|
||||
mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
|
||||
return 0;
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
mmctx->type = "iq4nlx4x2-f32";
|
||||
mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1;
|
||||
mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1;
|
||||
mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2;
|
||||
return 0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
mmctx->type = "mxfp4x4x2-f32";
|
||||
mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
|
||||
@@ -2556,6 +2926,13 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
// Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
// SKIP_QUANTIZE: Q8 data lives at the address written by the previous
|
||||
// quantize pass. The current op may have a different src0 size (e.g.
|
||||
// IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong.
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
@@ -2659,6 +3036,9 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
|
||||
@@ -690,7 +690,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
||||
" auto tB = B.slice((int)tgid.x, 0); \n"
|
||||
" \n"
|
||||
" matmul2d< \n"
|
||||
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
|
||||
" matmul2d_descriptor(16, 16, dynamic_extent), \n"
|
||||
" execution_simdgroups<4>> mm; \n"
|
||||
" \n"
|
||||
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
|
||||
@@ -740,7 +740,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
|
||||
" auto tB = B.slice((int)tgid.x, 0); \n"
|
||||
" \n"
|
||||
" matmul2d< \n"
|
||||
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
|
||||
" matmul2d_descriptor(16, 16, dynamic_extent), \n"
|
||||
" execution_simdgroups<4>> mm; \n"
|
||||
" \n"
|
||||
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
|
||||
|
||||
@@ -394,6 +394,9 @@ struct ggml_backend_opencl_context {
|
||||
bool fp16_support;
|
||||
bool has_vector_subgroup_broadcast;
|
||||
bool disable_fusion;
|
||||
|
||||
bool adreno_has_large_buffer;
|
||||
bool adreno_use_large_buffer;
|
||||
ggml_cl_compiler_version adreno_cl_compiler_version;
|
||||
|
||||
int adreno_wave_size;
|
||||
@@ -787,6 +790,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
" -cl-mad-enable -cl-unsafe-math-optimizations"
|
||||
" -cl-finite-math-only -cl-fast-relaxed-math";
|
||||
|
||||
if (backend_ctx->adreno_use_large_buffer) {
|
||||
compile_opts += " -qcom-enable-large-buffer ";
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels");
|
||||
|
||||
// add
|
||||
@@ -3020,6 +3027,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
// Check if ext_buffer contains cl_khr_fp16
|
||||
backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL;
|
||||
GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false");
|
||||
// check Adreno large buffer support
|
||||
backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL;
|
||||
|
||||
// fp16 is required
|
||||
if (!backend_ctx->fp16_support) {
|
||||
@@ -3086,6 +3095,18 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n");
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
// determine whether to use large buffer for Adreno
|
||||
backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr &&
|
||||
backend_ctx->gpu_family == GPU_FAMILY::ADRENO;
|
||||
if (backend_ctx->adreno_use_large_buffer) {
|
||||
if (!backend_ctx->adreno_has_large_buffer) {
|
||||
GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n");
|
||||
backend_ctx->adreno_use_large_buffer = false;
|
||||
} else {
|
||||
GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n");
|
||||
}
|
||||
}
|
||||
|
||||
cl_int err;
|
||||
|
||||
// A local ref of cl_context for convenience
|
||||
@@ -5660,6 +5681,11 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b
|
||||
|
||||
cl_int err;
|
||||
cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err);
|
||||
if (err != CL_SUCCESS && backend_ctx->adreno_use_large_buffer) {
|
||||
cl_mem_properties props[] = { 0x41A6 /* CL_LARGE_BUFFER_QCOM */, 1, 0 };
|
||||
mem = clCreateBufferWithProperties(backend_ctx->context, props, CL_MEM_READ_WRITE, size, NULL, &err);
|
||||
}
|
||||
|
||||
if (err != CL_SUCCESS) {
|
||||
GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0);
|
||||
return nullptr;
|
||||
|
||||
@@ -589,8 +589,10 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||
ggml_backend_buffer_t buffer = tensor->buffer;
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
|
||||
result.data = reinterpret_cast<uint64_t>(tensor->data);
|
||||
} else {
|
||||
result.buffer = 0;
|
||||
result.data = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
result.ne[i] = tensor->ne[i];
|
||||
@@ -606,7 +608,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||
}
|
||||
result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
|
||||
result.view_offs = tensor->view_offs;
|
||||
result.data = reinterpret_cast<uint64_t>(tensor->data);
|
||||
|
||||
// Avoid sending uninitialized data over the wire
|
||||
memset(result.name, 0, sizeof(result.name));
|
||||
@@ -1443,9 +1444,11 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
||||
const rpc_tensor * tensor = it_ptr->second;
|
||||
|
||||
struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
|
||||
if (result == nullptr || result->buffer == nullptr) {
|
||||
GGML_LOG_ERROR("[%s] invalid tensor: null %s (id=%" PRIu64 ")\n",
|
||||
__func__, result == nullptr ? "tensor" : "buffer", id);
|
||||
if (result == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// serialize_tensor sends buffer = 0 together with data = 0 for tensors that
// have no buffer, so a non-null data pointer without a buffer is inconsistent
// input and is rejected. The log line ends with '\n' like the other
// GGML_LOG_* messages so it does not run into the next log entry.
if (result->buffer == nullptr && result->data != nullptr) {
|
||||
GGML_LOG_ERROR("[%s] invalid data ptr\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
tensor_map[id] = result;
|
||||
|
||||
@@ -63,6 +63,7 @@ class TensorNameMap:
|
||||
"transformer.wpe", # gpt2
|
||||
"embeddings.position_embeddings", # bert
|
||||
"wpe", # gpt2
|
||||
"model.embed_positions", # rugpt3xl
|
||||
),
|
||||
|
||||
# Output
|
||||
|
||||
@@ -7578,6 +7578,65 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
if (!layer.ssm_beta_s && layer.ssm_beta) {
|
||||
layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
|
||||
// input scales
|
||||
if (!layer.wq_in_s && layer.wq) {
|
||||
layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.wk_in_s && layer.wk) {
|
||||
layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.wv_in_s && layer.wv) {
|
||||
layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.wo_in_s && layer.wo) {
|
||||
layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.wqkv_in_s && layer.wqkv) {
|
||||
layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.wqkv_gate_in_s && layer.wqkv_gate) {
|
||||
layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_gate_in_s && layer.ffn_gate) {
|
||||
layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_down_in_s && layer.ffn_down) {
|
||||
layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_up_in_s && layer.ffn_up) {
|
||||
layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) {
|
||||
layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) {
|
||||
layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) {
|
||||
layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) {
|
||||
layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) {
|
||||
layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) {
|
||||
layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ssm_in_in_s && layer.ssm_in) {
|
||||
layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ssm_out_in_s && layer.ssm_out) {
|
||||
layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ssm_alpha_in_s && layer.ssm_alpha) {
|
||||
layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
if (!layer.ssm_beta_in_s && layer.ssm_beta) {
|
||||
layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -414,6 +414,27 @@ struct llama_layer {
|
||||
struct ggml_tensor * ssm_alpha_s = nullptr;
|
||||
struct ggml_tensor * ssm_beta_s = nullptr;
|
||||
|
||||
// input scales
|
||||
struct ggml_tensor * wq_in_s = nullptr;
|
||||
struct ggml_tensor * wk_in_s = nullptr;
|
||||
struct ggml_tensor * wv_in_s = nullptr;
|
||||
struct ggml_tensor * wo_in_s = nullptr;
|
||||
struct ggml_tensor * wqkv_in_s = nullptr;
|
||||
struct ggml_tensor * wqkv_gate_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_gate_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_up_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_down_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_gate_exps_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_down_exps_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_up_exps_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_gate_shexp_in_s= nullptr;
|
||||
struct ggml_tensor * ffn_up_shexp_in_s = nullptr;
|
||||
struct ggml_tensor * ffn_down_shexp_in_s= nullptr;
|
||||
struct ggml_tensor * ssm_in_in_s = nullptr;
|
||||
struct ggml_tensor * ssm_out_in_s = nullptr;
|
||||
struct ggml_tensor * ssm_alpha_in_s = nullptr;
|
||||
struct ggml_tensor * ssm_beta_in_s = nullptr;
|
||||
|
||||
// altup & laurel
|
||||
struct ggml_tensor * per_layer_inp_gate = nullptr;
|
||||
struct ggml_tensor * per_layer_proj = nullptr;
|
||||
|
||||
+4
-1
@@ -345,9 +345,12 @@ static bool tensor_allows_quantization(const llama_model_quantize_params * param
|
||||
|
||||
// do not quantize specific multimodal tensors
|
||||
quantize &= name.find(".position_embd") == std::string::npos;
|
||||
quantize &= name.find("sam.patch_embd") == std::string::npos;
|
||||
quantize &= name.find("sam.pos_embd") == std::string::npos;
|
||||
quantize &= name.find("sam.neck.") == std::string::npos;
|
||||
quantize &= name.find("sam.net_") == std::string::npos;
|
||||
quantize &= name.find(".rel_pos") == std::string::npos;
|
||||
quantize &= name.find(".patch_embd") == std::string::npos;
|
||||
quantize &= name.find(".patch_merger") == std::string::npos;
|
||||
|
||||
return quantize;
|
||||
}
|
||||
|
||||
+23
-14
@@ -4823,28 +4823,33 @@ struct test_conv_transpose_1d : public test_case {
|
||||
|
||||
// GGML_OP_CONV_TRANSPOSE_2D
|
||||
struct test_conv_transpose_2d : public test_case {
|
||||
// Dimensions
|
||||
const std::array<int64_t, 4> ne_input;
|
||||
const std::array<int64_t, 4> ne_kernel;
|
||||
const int stride;
|
||||
// Types
|
||||
const ggml_type kernel_type;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(ne_input, ne_kernel, stride);
|
||||
return VARS_TO_STR4(kernel_type, ne_input, ne_kernel, stride);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-4; // The default 1e-7 is too small for Vulkan.
|
||||
}
|
||||
|
||||
test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
|
||||
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
|
||||
int stride = 1)
|
||||
: ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
|
||||
test_conv_transpose_2d(
|
||||
std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
|
||||
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
|
||||
int stride = 1,
|
||||
ggml_type kernel_type = GGML_TYPE_F16
|
||||
) : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), kernel_type(kernel_type) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
|
||||
ggml_set_name(input, "input");
|
||||
|
||||
ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
|
||||
ggml_tensor * kernel = ggml_new_tensor(ctx, kernel_type, 4, ne_kernel.data());
|
||||
ggml_set_name(kernel, "kernel");
|
||||
|
||||
ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
|
||||
@@ -7279,7 +7284,7 @@ static const ggml_type all_types[] = {
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_MXFP4,
|
||||
GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K,
|
||||
@@ -7295,7 +7300,7 @@ static const ggml_type base_types[] = {
|
||||
GGML_TYPE_Q4_0,
|
||||
GGML_TYPE_Q4_1, // for I8MM tests
|
||||
GGML_TYPE_Q4_K,
|
||||
GGML_TYPE_MXFP4, // TODO: or "other"
|
||||
GGML_TYPE_MXFP4, GGML_TYPE_NVFP4, // TODO: or "other"
|
||||
GGML_TYPE_IQ2_XXS
|
||||
};
|
||||
|
||||
@@ -7704,9 +7709,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
|
||||
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
||||
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({129, 63, 35, 1}, {3, 3, 48, 35}, 1));
|
||||
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({129, 63, 35, 1}, {3, 3, 48, 35}, 1, kernel_type));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
|
||||
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
|
||||
@@ -8892,9 +8899,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
|
||||
test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
|
||||
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
|
||||
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1, kernel_type));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1, kernel_type));
|
||||
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
|
||||
|
||||
|
||||
@@ -1330,7 +1330,7 @@ static void test_nemotron_reasoning_detection(testing & t) {
|
||||
analysis.analyze_template(tmpl);
|
||||
|
||||
// Check reasoning markers
|
||||
t.assert_equal("reasoning_start should be '<think>'", "<think>", analysis.reasoning.start);
|
||||
t.assert_equal("reasoning_start should be '<think>\\n'", "<think>\n", analysis.reasoning.start);
|
||||
t.assert_equal("reasoning_end should be '</think>'", "</think>", analysis.reasoning.end);
|
||||
|
||||
// Check reasoning mode detection
|
||||
|
||||
+332
-81
@@ -805,7 +805,8 @@ struct peg_test_case {
|
||||
common_chat_templates_inputs params;
|
||||
std::string input;
|
||||
common_chat_msg expect;
|
||||
bool is_partial = false;
|
||||
bool is_partial = false;
|
||||
bool expect_reconstruction = false;
|
||||
};
|
||||
|
||||
struct make_peg_parser {
|
||||
@@ -828,6 +829,12 @@ struct make_peg_parser {
|
||||
}
|
||||
};
|
||||
|
||||
// Global template filter for --template flag
|
||||
static std::string g_template_filter;
|
||||
|
||||
// When true, run reconstruction test on every non-partial test and report results
|
||||
static bool g_force_reconstruction_test = false;
|
||||
|
||||
static void test_peg_parser(common_chat_templates * tmpls,
|
||||
const std::function<void(peg_test_case &)> & init,
|
||||
bool detailed_debug) {
|
||||
@@ -936,75 +943,158 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
||||
throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar);
|
||||
}
|
||||
|
||||
// Find the earliest trigger position to determine the constrained portion
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
size_t pos = std::string::npos;
|
||||
std::smatch match;
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = tc.input.find(word);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & pattern = std::regex(trigger.value);
|
||||
if (std::regex_search(tc.input, match, pattern)) {
|
||||
pos = match.position(pattern.mark_count());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
if (std::regex_match(tc.input, match, std::regex(pattern))) {
|
||||
auto mpos = std::string::npos;
|
||||
for (size_t i = 1; i < match.size(); ++i) {
|
||||
if (match[i].length() > 0) {
|
||||
mpos = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mpos == std::string::npos) {
|
||||
mpos = match.position(0);
|
||||
}
|
||||
pos = mpos;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown trigger type");
|
||||
}
|
||||
if (pos != std::string::npos) {
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = pos;
|
||||
// In production, grammar triggers match against the full generated text
|
||||
// including the generation prompt. All positions are in full_input coordinates.
|
||||
const auto & gen_prompt = parser.params_.generation_prompt;
|
||||
std::string full_input = gen_prompt + tc.input;
|
||||
|
||||
// Determine whether the reasoning-budget sampler path applies: tool-call grammar
|
||||
// with all WORD triggers and thinking tags present. In production, the reasoning
|
||||
// budget sampler inhibits grammar application while inside thinking blocks —
|
||||
// triggers inside <think>...</think> are suppressed.
|
||||
bool use_reasoning_budget_path = false;
|
||||
if (parser.params_.grammar_lazy && !parser.params_.thinking_end_tag.empty()) {
|
||||
use_reasoning_budget_path = true;
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
if (trigger.type != COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||
use_reasoning_budget_path = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the constrained portion of input to test against grammar
|
||||
std::string constrained = tc.input;
|
||||
// Find the earliest trigger position to determine the constrained portion
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
|
||||
if (use_reasoning_budget_path) {
|
||||
// Reasoning-budget path: simulate thinking-aware trigger detection.
|
||||
// Walk through full_input tracking thinking state; only match triggers
|
||||
// when outside thinking blocks.
|
||||
const auto & think_start = parser.params_.thinking_start_tag;
|
||||
const auto & think_end = parser.params_.thinking_end_tag;
|
||||
|
||||
bool in_thinking = false;
|
||||
for (size_t i = 0; i < full_input.size(); ++i) {
|
||||
if (!in_thinking && !think_start.empty()
|
||||
&& full_input.compare(i, think_start.size(), think_start) == 0) {
|
||||
in_thinking = true;
|
||||
i += think_start.size() - 1;
|
||||
continue;
|
||||
}
|
||||
if (in_thinking && full_input.compare(i, think_end.size(), think_end) == 0) {
|
||||
in_thinking = false;
|
||||
i += think_end.size() - 1;
|
||||
continue;
|
||||
}
|
||||
if (in_thinking) {
|
||||
continue;
|
||||
}
|
||||
// Outside thinking — check if any trigger word starts here
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
if (full_input.compare(i, trigger.value.size(), trigger.value) == 0) {
|
||||
if (earliest_trigger_pos == std::string::npos || i < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (earliest_trigger_pos != std::string::npos) {
|
||||
break; // found the earliest
|
||||
}
|
||||
}
|
||||
|
||||
// If the reasoning-budget path found no trigger outside thinking but the test
|
||||
// expects tool calls, this template nests tool calls inside thinking
|
||||
// blocks (e.g. Kimi). Fall back to the legacy path for this case.
|
||||
if (earliest_trigger_pos == std::string::npos && !tc.expect.tool_calls.empty()) {
|
||||
use_reasoning_budget_path = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!use_reasoning_budget_path) {
|
||||
// Legacy path: find triggers without thinking-awareness
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
size_t pos = std::string::npos;
|
||||
std::smatch match;
|
||||
switch (trigger.type) {
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = full_input.find(word);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & compiled = std::regex(trigger.value);
|
||||
if (std::regex_search(full_input, match, compiled)) {
|
||||
pos = match.position(compiled.mark_count());
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
// In production, PATTERN_FULL triggers are checked against
|
||||
// the text generated so far, growing token by token. Simulate
|
||||
// by trying every prefix of full_input.
|
||||
const auto & compiled = std::regex(trigger.value);
|
||||
for (size_t end = gen_prompt.size(); end <= full_input.size(); ++end) {
|
||||
std::string prefix = full_input.substr(0, end);
|
||||
if (std::regex_match(prefix, match, compiled)) {
|
||||
pos = std::string::npos;
|
||||
for (size_t gi = 1; gi < match.size(); ++gi) {
|
||||
if (match[gi].length() > 0) {
|
||||
pos = match.position(gi);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos == std::string::npos) {
|
||||
pos = match.position(0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unknown trigger type");
|
||||
}
|
||||
if (pos != std::string::npos) {
|
||||
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
||||
earliest_trigger_pos = pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the test expects tool calls and the grammar is lazy, the trigger must fire.
|
||||
// Otherwise the grammar would never activate in production and tool calls wouldn't
|
||||
// be constrained. A silent skip here would hide broken triggers.
|
||||
if (parser.params_.grammar_lazy && !tc.expect.tool_calls.empty() && !tc.is_partial
|
||||
&& earliest_trigger_pos == std::string::npos) {
|
||||
std::string trigger_desc;
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
trigger_desc += "\n [type=" + std::to_string(trigger.type) + "] " + trigger.value;
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"Grammar trigger did not fire, but test expects tool calls (lazy grammar).\n"
|
||||
">>> Input: " + full_input + "\n"
|
||||
">>> Triggers (" + std::to_string(parser.params_.grammar_triggers.size()) + "):" + trigger_desc);
|
||||
}
|
||||
|
||||
// Determine the constrained portion of input to test against grammar.
|
||||
// If the trigger position falls inside the generation prompt, the grammar
|
||||
// sampler was already active before model output began — constrain from the
|
||||
// start of the model output (i.e. tc.input).
|
||||
std::string constrained = full_input;
|
||||
bool grammar_triggered = false;
|
||||
if (earliest_trigger_pos != std::string::npos) {
|
||||
constrained = tc.input.substr(earliest_trigger_pos);
|
||||
auto constrain_from = std::max(earliest_trigger_pos, gen_prompt.size());
|
||||
constrained = full_input.substr(constrain_from);
|
||||
grammar_triggered = true;
|
||||
} else if (!parser.params_.grammar_lazy) {
|
||||
// For non-lazy grammars, the entire input should match
|
||||
grammar_triggered = true;
|
||||
}
|
||||
|
||||
// For non-lazy grammars, prepend reasoning prefill to grammar input, just like
|
||||
// PEG parsing does. The grammar includes the full reasoning pattern (e.g. optional
|
||||
// <think>...</think>), but the model output may start mid-reasoning if the template
|
||||
// already placed the opening tag in the prompt.
|
||||
// For lazy grammars, the grammar only activates from the trigger position, so the
|
||||
// reasoning prefill is irrelevant — reasoning is handled by the PEG parser.
|
||||
if (!parser.params_.generation_prompt.empty() && earliest_trigger_pos == std::string::npos) {
|
||||
constrained = parser.params_.generation_prompt + constrained;
|
||||
}
|
||||
|
||||
// Test the constrained portion against the grammar
|
||||
if (grammar_triggered && !tc.is_partial) {
|
||||
auto result = match_string_detailed(constrained, grammar.get());
|
||||
@@ -1036,10 +1126,57 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Global template filter for --template flag
|
||||
static std::string g_template_filter;
|
||||
// Reconstruction test: verify that appending the parsed message to the original
|
||||
// messages and re-rendering the template (without generation prompt) reproduces
|
||||
// the original prompt + input exactly, or as a proper prefix (the template may
|
||||
// append end-of-turn tokens after the assistant message).
|
||||
if ((tc.expect_reconstruction || g_force_reconstruction_test) && !tc.is_partial) {
|
||||
// Start from tc.expect but copy tool call arguments from the actual parser
|
||||
// output, which preserves original JSON formatting (e.g. {"arg1":1} vs {"arg1": 1}).
|
||||
auto reconstruction_msg = tc.expect;
|
||||
auto parsed_msg = parser.parse(tc.input, false);
|
||||
for (size_t i = 0; i < reconstruction_msg.tool_calls.size() && i < parsed_msg.tool_calls.size(); i++) {
|
||||
reconstruction_msg.tool_calls[i].arguments = parsed_msg.tool_calls[i].arguments;
|
||||
}
|
||||
common_chat_templates_inputs reconstruction_inputs = tc.params;
|
||||
reconstruction_inputs.messages.push_back(reconstruction_msg);
|
||||
reconstruction_inputs.add_generation_prompt = false;
|
||||
|
||||
auto reconstruction_params = common_chat_templates_apply(tmpls, reconstruction_inputs);
|
||||
std::string expected_text = parser.params_.prompt + tc.input;
|
||||
bool match = reconstruction_params.prompt == expected_text ||
|
||||
(reconstruction_params.prompt.size() > expected_text.size() &&
|
||||
reconstruction_params.prompt.compare(0, expected_text.size(), expected_text) == 0);
|
||||
if (!match && g_force_reconstruction_test && !tc.expect_reconstruction) {
|
||||
// In forced mode, report mismatch but don't fail
|
||||
// Find the first difference position
|
||||
size_t diff_pos = 0;
|
||||
size_t min_len = std::min(expected_text.size(), reconstruction_params.prompt.size());
|
||||
while (diff_pos < min_len && expected_text[diff_pos] == reconstruction_params.prompt[diff_pos]) {
|
||||
diff_pos++;
|
||||
}
|
||||
size_t ctx_start = diff_pos > 60 ? diff_pos - 60 : 0;
|
||||
size_t ctx_end_e = std::min(expected_text.size(), diff_pos + 40);
|
||||
size_t ctx_end_r = std::min(reconstruction_params.prompt.size(), diff_pos + 40);
|
||||
LOG_ERR("\x1b[31m[RECONSTRUCTION FAIL]\x1b[0m "
|
||||
"first diff at byte %zu (expected len=%zu, reconstructed len=%zu)\n"
|
||||
" expected: ...%s...\n"
|
||||
" reconstructed: ...%s...\n",
|
||||
diff_pos, expected_text.size(), reconstruction_params.prompt.size(),
|
||||
expected_text.substr(ctx_start, ctx_end_e - ctx_start).c_str(),
|
||||
reconstruction_params.prompt.substr(ctx_start, ctx_end_r - ctx_start).c_str());
|
||||
} else if (!match) {
|
||||
std::string error_msg =
|
||||
"Reconstruction mismatch:\n\n"
|
||||
">>> Expected (prompt + input):\n" + expected_text +
|
||||
"\n\n>>> Reconstructed:\n" + reconstruction_params.prompt;
|
||||
throw std::runtime_error(error_msg);
|
||||
} else if (g_force_reconstruction_test) {
|
||||
LOG_INF("\x1b[32m[RECONSTRUCTION OK]\x1b[0m\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fluent builder for PEG parser tests
|
||||
class peg_test_builder;
|
||||
@@ -1099,6 +1236,11 @@ class peg_test_builder {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Builder setter: stores `val` (default true) in tc_.expect_reconstruction and
// returns *this so the call can be chained with the other builder calls.
// NOTE(review): the flag is read by the reconstruction check in the test runner
// (a mismatch throws unless forced mode is on and this flag is false) - confirm
// against the full runner, which is only partially visible here.
peg_test_builder & expect_reconstruction(bool val = true) {
|
||||
tc_.expect_reconstruction = val;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Expect setters
|
||||
peg_test_builder & expect(const common_chat_msg & msg) {
|
||||
tc_.expect = msg;
|
||||
@@ -1272,16 +1414,18 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Ministral-3-14B-Reasoning-2512
|
||||
auto tst = peg_tester("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
|
||||
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
|
||||
.expect_content("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.expect(message_assist_thoughts)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})")
|
||||
@@ -1311,6 +1455,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -1323,6 +1468,20 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_reasoning("I need to output the invoice details in JSON")
|
||||
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.run();
|
||||
|
||||
// fake tool call marker in reasoning
|
||||
tst.test(
|
||||
"[THINK]Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more[/THINK]"
|
||||
R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.tools({ special_function_tool })
|
||||
.expect_reasoning("Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more")
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1425,6 +1584,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_reasoning("I need to output the invoice details in JSON")
|
||||
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.run();
|
||||
|
||||
// tool call segment in reasoning
|
||||
tst.test(
|
||||
"Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Not the real call!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call></think>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({
|
||||
python_tool
|
||||
})
|
||||
.expect_reasoning("Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Not the real call!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1481,9 +1684,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Google Gemma 2 2B - does not support tool calling
|
||||
auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja");
|
||||
|
||||
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run();
|
||||
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).expect_reconstruction().run();
|
||||
|
||||
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).run();
|
||||
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
@@ -1526,7 +1729,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Test simple content-only template
|
||||
auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
{
|
||||
// IBM Granite (reasoning and tool calling model)
|
||||
@@ -1638,7 +1841,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Qwen3-Coder (tool calling with XML-style format)
|
||||
auto tst = peg_tester("models/templates/Qwen3-Coder.jinja", detailed_debug);
|
||||
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
@@ -1650,6 +1853,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
"</tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -1678,6 +1882,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with code content (multiline)
|
||||
@@ -1698,6 +1903,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with code content (asian unicode chars)
|
||||
@@ -1715,6 +1921,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"格\"}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with HTML tag content
|
||||
@@ -1736,6 +1943,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "html", "{\"markup\": \"<html>\\n <head>\\n <title>Hello!</title>\\n </head>\\n</html>\"}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test with TODO list (array of objects)
|
||||
@@ -1753,6 +1961,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "todo_list", "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test flexible optional argument ordering (2 required + 4 optional, reversed optional order)
|
||||
@@ -1769,6 +1978,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "tool_2req_4opt", R"({"req1": "hello", "req2": 42, "opt4": 100, "opt2": 200})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test flexible optional argument ordering (2 required + 5 optional, reversed optional order)
|
||||
@@ -1786,6 +1996,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "tool_2req_5opt", R"({"req1": "world", "req2": 7, "opt5": "last", "opt3": "middle", "opt1": "first"})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Test flexible optional argument ordering (2 required + 5 optional, all 5 in shuffled order)
|
||||
@@ -1805,6 +2016,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.expect_tool_calls({
|
||||
{ "tool_2req_5opt", R"({"req1": "test", "req2": 99, "opt3": "c", "opt1": "a", "opt5": "e", "opt4": 4, "opt2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
{
|
||||
@@ -1885,6 +2097,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
tst.test("Hello, world!\nWhat's up?")
|
||||
.enable_thinking(false)
|
||||
.expect(message_assist)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Reasoning with content (forced-open mode - input starts after <think>)
|
||||
@@ -1892,6 +2105,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.expect(message_assist_thoughts)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Tool call without reasoning
|
||||
@@ -1902,6 +2116,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.enable_thinking(false)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Tool call with reasoning (forced-open mode)
|
||||
@@ -1914,6 +2129,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_thoughts)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -1933,6 +2149,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// #20650: tool with no required args, model emits <tool_call>name</tool_call> with no arg tags.
|
||||
@@ -1950,6 +2167,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.tools({ no_args_tool })
|
||||
.expect_reasoning("Let me read the diff content.")
|
||||
.expect_tool_calls({{ "read_file_diff_md", "{}", {} }})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
}
|
||||
@@ -2208,22 +2426,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
|
||||
// Kimi-K2 old template
|
||||
auto tst = peg_tester("models/templates/moonshotai-Kimi-K2.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test(
|
||||
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>"
|
||||
"{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(kimi_id_special_func_tool_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Kimi-K2-Instruct
|
||||
auto tst2 = peg_tester("models/templates/Kimi-K2-Instruct.jinja", detailed_debug);
|
||||
tst2.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst2.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst2.test(
|
||||
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>"
|
||||
"{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(kimi_id_special_func_tool_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2297,6 +2517,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.tools({ empty_args_tool })
|
||||
.expect(simple_assist_msg("", "", "empty_args", "{}"))
|
||||
.run();
|
||||
|
||||
// fake tool call marker in reasoning
|
||||
tst.test(
|
||||
"<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
|
||||
"<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect_reasoning("Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm")
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
})
|
||||
.run();
|
||||
}
|
||||
|
||||
// Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format
|
||||
@@ -2306,6 +2539,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
tst.test("<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2314,7 +2548,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{
|
||||
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
|
||||
tst.test(
|
||||
"</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
@@ -2364,37 +2598,41 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// mistralai-Mistral-Nemo-Instruct-2407.jinja
|
||||
{
|
||||
auto tst = peg_tester("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_id)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
{
|
||||
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.1.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<function=special_function>{\"arg1\": 1}</function>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
// Functionary v3.2 - recipient-based format: >>>recipient\n{content}
|
||||
{
|
||||
auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug);
|
||||
tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("special_function\n{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
// FireFunction
|
||||
{
|
||||
auto tst = peg_tester("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test(" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2455,10 +2693,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
{ "models/templates/MiMo-VL.jinja", "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja",
|
||||
"models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" }) {
|
||||
auto tst = peg_tester(path, detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n</tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
@@ -2481,6 +2720,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.expect(simple_assist_msg("Hello, world!\nWhat's up?", "Here are my reasoning steps:\nI'm\nthinking"))
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// Reasoning + Tool calls
|
||||
@@ -2497,42 +2737,45 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Mistral Small 3.2 - FUNC_BRACKET_TAG format: [TOOL_CALLS]func_name[CALL_ID]id[ARGS]{...}
|
||||
{
|
||||
auto tst = peg_tester("models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_id)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
// Devstral
|
||||
{
|
||||
auto tst = peg_tester("models/templates/unsloth-mistral-Devstral-Small-2507.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
tst.test("Hello, world!\nWhat's up?[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_content)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// Llama 3.1
|
||||
auto tst = peg_tester("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
// Llama 3.2
|
||||
auto tst = peg_tester("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
// Llama 3.3
|
||||
auto tst = peg_tester("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
// GPT-OSS format tests
|
||||
@@ -2836,10 +3079,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// GigaChat V3
|
||||
{
|
||||
auto tst = peg_tester("models/templates/GigaChat3-10B-A1.8B.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<|message_sep|>\n\nfunction call<|role_sep|>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -2848,16 +3092,18 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_content)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
|
||||
// GigaChat V3.1
|
||||
{
|
||||
auto tst = peg_tester("models/templates/GigaChat3.1-10B-A1.8B.jinja", detailed_debug);
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run();
|
||||
tst.test("<|function_call|>{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
@@ -2866,6 +3112,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_content)
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
}
|
||||
}
|
||||
@@ -3002,6 +3249,10 @@ int main(int argc, char ** argv) {
|
||||
detailed_debug = true;
|
||||
common_log_set_verbosity_thold(999);
|
||||
}
|
||||
if (arg == "--force-reconstruction-test") {
|
||||
g_force_reconstruction_test = true;
|
||||
only_run_filtered = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (only_run_filtered) {
|
||||
|
||||
@@ -61,8 +61,6 @@ static void test_reasoning_budget(
|
||||
|
||||
// Feed the sequence and track when forcing occurs
|
||||
for (size_t i = 0; i < sequence.size(); i++) {
|
||||
llama_sampler_accept(sampler, sequence[i]);
|
||||
|
||||
// Check if we're in forcing state by applying and seeing if logits are modified
|
||||
cur_p.selected = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
@@ -81,6 +79,8 @@ static void test_reasoning_budget(
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_accept(sampler, sequence[i]);
|
||||
|
||||
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
|
||||
|
||||
if (finite_count == 1) {
|
||||
@@ -167,9 +167,9 @@ int main(void) {
|
||||
}
|
||||
|
||||
// Test 2: Budget exhausted, forcing should occur
|
||||
// Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
|
||||
// Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
|
||||
// At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
|
||||
// Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1
|
||||
// i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1]
|
||||
// At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
@@ -179,13 +179,12 @@ int main(void) {
|
||||
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
|
||||
3); // forcing continues through i=3 (at i=4 state becomes DONE)
|
||||
3, // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces)
|
||||
4); // forcing continues through i=4 (accept at i=4 transitions to DONE)
|
||||
}
|
||||
|
||||
// Test 3: Activate immediately with budget=0, forcing should start right away
|
||||
// Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
|
||||
// This test needs start token to be in the sequence or use activate_immediately with start token present
|
||||
// Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
@@ -195,8 +194,8 @@ int main(void) {
|
||||
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
|
||||
0, // budget of 0 tokens
|
||||
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
|
||||
0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
|
||||
1); // forcing continues through i=1 (at i=2 state becomes DONE)
|
||||
0, // forcing starts at i=0 (initialized in FORCING, apply forces immediately)
|
||||
1); // forcing continues through i=1 (accept at i=1 transitions to DONE)
|
||||
}
|
||||
|
||||
// Test 4: No start/end tokens configured - passthrough (no forcing)
|
||||
@@ -214,7 +213,7 @@ int main(void) {
|
||||
|
||||
// Test 5: Activate immediately with budget > 0, count down then force
|
||||
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
|
||||
// So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
|
||||
// Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned)
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
@@ -224,8 +223,8 @@ int main(void) {
|
||||
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_COUNTING,
|
||||
1, // forcing starts at i=1 (after 2 accepts deplete budget)
|
||||
2); // forcing continues through i=2
|
||||
2, // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces)
|
||||
3); // forcing continues through i=3
|
||||
}
|
||||
|
||||
printf("OK (5 tests passed)\n");
|
||||
|
||||
+81
-18
@@ -100,7 +100,7 @@ struct cli_context {
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(
|
||||
llama_get_model(ctx_server.get_llama_context()));
|
||||
|
||||
@@ -224,10 +224,11 @@ struct cli_context {
|
||||
};
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<const std::string, 6> cmds = {
|
||||
static const std::array<const std::string, 7> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
@@ -258,7 +259,7 @@ static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std:
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
@@ -339,6 +340,8 @@ static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std:
|
||||
return matches;
|
||||
}
|
||||
|
||||
// Upper bound on the number of files one /glob command may add: the glob loop
// prints an error and stops once this many files have been loaded successfully.
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
@@ -430,7 +433,8 @@ int main(int argc, char ** argv) {
|
||||
console::log(" /exit or Ctrl+C stop or exit\n");
|
||||
console::log(" /regen regenerate the last response\n");
|
||||
console::log(" /clear clear the chat history\n");
|
||||
console::log(" /read add a text file\n");
|
||||
console::log(" /read <file> add a text file\n");
|
||||
console::log(" /glob <pattern> add text files using globbing pattern\n");
|
||||
if (inf.has_inp_image) {
|
||||
console::log(" /image <file> add an image file\n");
|
||||
}
|
||||
@@ -441,6 +445,27 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
// Loads the text file `fname` via ctx_cli.load_input_file and appends it to the
// pending user message `cur_msg` (captured by reference).
// Returns false, after printing an error, when load_input_file yields an empty
// string; returns true once the file header and the returned string were appended.
// NOTE(review): `marker` is presumably a placeholder/text standing in for the
// loaded file - confirm against load_input_file.
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::string marker = ctx_cli.load_input_file(fname, false);
|
||||
// An empty result is treated as "file missing or unreadable".
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
return false;
|
||||
}
|
||||
// Prefix the content with a header naming the file: the FIM separator token
// piece followed by the file name when the model defines such a token...
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
// ...otherwise a plain-text "--- File: <name> ---" line.
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
@@ -525,22 +550,60 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
std::string marker = ctx_cli.load_input_file(fname, false);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
continue;
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, "~")) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = std::string(home) + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
|
||||
break;
|
||||
}
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
|
||||
@@ -146,13 +146,19 @@ int main(int argc, char ** argv) {
|
||||
|
||||
ctx = llama_init->context();
|
||||
model = llama_init->model();
|
||||
smpl = llama_init->sampler(0);
|
||||
|
||||
if (ctx == NULL) {
|
||||
LOG_ERR("%s: error: unable to create context\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: error: unable to load model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
smpl = llama_init->sampler(0);
|
||||
|
||||
llama_memory_t mem = llama_get_memory(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
@@ -695,7 +701,7 @@ int main(int argc, char ** argv) {
|
||||
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
|
||||
return 1;
|
||||
}
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||
n_session_consumed = session_tokens.size();
|
||||
session_do_save = false;
|
||||
|
||||
|
||||
@@ -143,11 +143,20 @@ static void compute_statistics(std::vector<tensor_statistics> & tstats, const st
|
||||
activations.reserve(e.values.size());
|
||||
|
||||
for (int i = 0; i < n_mat; ++i) {
|
||||
if (e.counts[i] == 0) {
|
||||
LOG_DBG("%s: skipping tensor %s due to zero count at index %d\n", __func__, name.c_str(), i);
|
||||
continue;
|
||||
}
|
||||
for (int j = 0; j < row_size; ++j) {
|
||||
activations.push_back(e.values[i*row_size + j] / e.counts[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (activations.empty()) {
|
||||
LOG_ERR("%s: all counts are zero for tensor %s, skipping statistics computation\n", __func__, name.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
const float act_total = std::accumulate(activations.begin(), activations.end(), 0.0f);
|
||||
const float act_max = *std::max_element(activations.begin(), activations.end());
|
||||
const float act_min = *std::min_element(activations.begin(), activations.end());
|
||||
@@ -1142,10 +1151,12 @@ static bool show_statistics(const common_params & params) {
|
||||
blk = -1; // not a block layer
|
||||
}
|
||||
|
||||
const float entropy_norm = (tstat.elements > 0) ? 100.0f * (tstat.entropy / std::log2(tstat.elements)) : 0.0f;
|
||||
|
||||
LOG_INF("%5s\t%-20s\t%10.2f\t%8.4f\t%11.4f\t%6.2f\t%6.2f\t%8.2f%%\t%6d\t%10.4f\t%6.2f%%\t%10.2f%%\t%8.4f\n",
|
||||
layer.c_str(), name.c_str(), tstat.total_sqract, tstat.min_sqract, tstat.max_sqract, tstat.mean_sqract,
|
||||
tstat.stddev, tstat.active * 100.0f, tstat.elements, tstat.entropy,
|
||||
100.0f * (tstat.entropy / std::log2(tstat.elements)), 100.0f * tstat.zd, tstat.cossim);
|
||||
entropy_norm, 100.0f * tstat.zd, tstat.cossim);
|
||||
|
||||
const float weighted_bias = tstat.elements * tstat.total_sqract;
|
||||
const float weighted_zd = tstat.elements * tstat.zd;
|
||||
|
||||
@@ -5,6 +5,7 @@ find_package(Threads REQUIRED)
|
||||
add_library(mtmd
|
||||
mtmd.cpp
|
||||
mtmd-audio.cpp
|
||||
mtmd-image.cpp
|
||||
mtmd.h
|
||||
mtmd-helper.cpp
|
||||
mtmd-helper.h
|
||||
|
||||
@@ -51,7 +51,6 @@
|
||||
|
||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
||||
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
|
||||
#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes"
|
||||
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
|
||||
|
||||
+21
-3
@@ -28,6 +28,13 @@ enum patch_merge_type {
|
||||
PATCH_MERGE_SPATIAL_UNPAD,
|
||||
};
|
||||
|
||||
// Identifiers for the image resize algorithm.
// The numeric values are the ones the compiler assigned implicitly before; they are only spelled out here.
enum resize_algo {
    RESIZE_ALGO_BILINEAR       = 0, // stretch to target resolution
    RESIZE_ALGO_BICUBIC        = 1, // center-crop when aspect ratio doesn't match
    RESIZE_ALGO_BICUBIC_PILLOW = 2, // NOTE(review): presumably bicubic resampling that mirrors Pillow's output - confirm
    // RESIZE_ALGO_LANCZOS, // TODO
};
|
||||
|
||||
struct clip_hparams {
|
||||
int32_t image_size = 0;
|
||||
int32_t patch_size = 0;
|
||||
@@ -37,13 +44,26 @@ struct clip_hparams {
|
||||
int32_t n_head = 0;
|
||||
int32_t n_layer = 0;
|
||||
// idefics3
|
||||
int32_t n_merge = 0; // number of patch merges **per-side**
|
||||
|
||||
// for preprocessor
|
||||
int32_t image_longest_edge = 0;
|
||||
int32_t image_min_pixels = -1;
|
||||
int32_t image_max_pixels = -1;
|
||||
int32_t n_merge = 0; // number of patch merges **per-side**
|
||||
resize_algo image_resize_algo = RESIZE_ALGO_BICUBIC;
|
||||
bool image_resize_pad = true; // if false, center-crop will be applied when resizing
|
||||
std::array<uint8_t, 3> image_pad_color = {0, 0, 0};
|
||||
|
||||
// (preprocessor) for llava-uhd style models
|
||||
std::vector<clip_image_size> image_res_candidates;
|
||||
int32_t preproc_min_tiles = 0;
|
||||
int32_t preproc_max_tiles = 0;
|
||||
resize_algo image_resize_algo_rf = RESIZE_ALGO_BICUBIC;
|
||||
resize_algo image_resize_algo_ov = RESIZE_ALGO_BILINEAR;
|
||||
bool image_pad_rf = true; // if true, refined image will be padded (e.g. llava-1.6)
|
||||
bool image_pad_ov = false; // if true, overview image will be padded (e.g. llava-1.6)
|
||||
std::array<uint8_t, 3> image_pad_color_rf = {0, 0, 0}; // padding color for refined image
|
||||
std::array<uint8_t, 3> image_pad_color_ov = {0, 0, 0}; // padding color for overview image
|
||||
|
||||
float image_mean[3];
|
||||
float image_std[3];
|
||||
@@ -60,8 +80,6 @@ struct clip_hparams {
|
||||
float eps = 1e-6;
|
||||
float rope_theta = 0.0;
|
||||
|
||||
std::vector<clip_image_size> image_res_candidates; // for llava-uhd style models
|
||||
int32_t image_crop_resolution;
|
||||
std::unordered_set<int32_t> vision_feature_layer;
|
||||
int32_t attn_window_size = 0;
|
||||
int32_t n_wa_pattern = 0;
|
||||
|
||||
+62
-1398
File diff suppressed because it is too large
Load Diff
@@ -97,9 +97,6 @@ struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch
|
||||
*/
|
||||
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
|
||||
|
||||
/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
|
||||
bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
|
||||
|
||||
struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
|
||||
|
||||
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
|
||||
|
||||
@@ -13,23 +13,20 @@
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
|
||||
void mtmd_audio_cache::fill_sin_cos_table(int n) {
|
||||
void mtmd_audio_cache::fill_sin_cos_table(uint32_t n) {
|
||||
sin_vals.resize(n);
|
||||
cos_vals.resize(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (uint32_t i = 0; i < n; i++) {
|
||||
double theta = (2 * M_PI * i) / n;
|
||||
sin_vals[i] = sinf(theta);
|
||||
cos_vals[i] = cosf(theta);
|
||||
}
|
||||
}
|
||||
|
||||
void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
|
||||
void mtmd_audio_cache::fill_hann_window(uint32_t length, bool periodic) {
|
||||
hann_window.resize(length);
|
||||
int offset = -1;
|
||||
if (periodic) {
|
||||
offset = 0;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
int offset = periodic ? 0 : -1;
|
||||
for (uint32_t i = 0; i < length; i++) {
|
||||
hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
||||
}
|
||||
}
|
||||
@@ -165,6 +162,7 @@ static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, fl
|
||||
// false = input is complex-valued (interleaved real/imag, stride 2)
|
||||
template <bool Inverse, bool RealInput>
|
||||
static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
|
||||
GGML_ASSERT(N > 0);
|
||||
const int n_sin_cos_vals = cache.sin_vals.size();
|
||||
|
||||
if (N == 1) {
|
||||
@@ -407,6 +405,8 @@ static bool log_mel_spectrogram(
|
||||
}
|
||||
|
||||
|
||||
GGML_ASSERT(params.n_fft_bins > 0);
|
||||
GGML_ASSERT(params.hop_length > 0);
|
||||
out.n_mel = params.n_mel;
|
||||
out.n_len = (n_samples - frame_size) / frame_step + 1;
|
||||
// TODO: handle these checks better
|
||||
@@ -438,6 +438,7 @@ static bool log_mel_spectrogram(
|
||||
|
||||
const int effective_n_len = n_samples_in / frame_step;
|
||||
if (params.norm_per_feature) {
|
||||
GGML_ASSERT(effective_n_len > 1);
|
||||
for (int i = 0; i < out.n_mel; i++) {
|
||||
double mean = 0;
|
||||
for (int j = 0; j < effective_n_len; ++j) {
|
||||
@@ -639,6 +640,7 @@ mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length
|
||||
padding_to_remove((n_fft - hop_length) / 2),
|
||||
ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT
|
||||
ifft_out(n_fft * 2 * 4, 0.0f) {
|
||||
GGML_ASSERT(n_fft > 0 && hop_length > 0 && hop_length <= n_fft);
|
||||
cache.fill_sin_cos_table(n_fft);
|
||||
cache.fill_hann_window(n_fft, true);
|
||||
}
|
||||
|
||||
@@ -33,9 +33,9 @@ struct mtmd_audio_cache {
|
||||
|
||||
mtmd_audio_mel_filters filters;
|
||||
|
||||
void fill_sin_cos_table(int n);
|
||||
void fill_sin_cos_table(uint32_t n);
|
||||
|
||||
void fill_hann_window(int length, bool periodic);
|
||||
void fill_hann_window(uint32_t length, bool periodic);
|
||||
|
||||
// Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
|
||||
// n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
|
||||
|
||||
@@ -127,6 +127,7 @@ struct decode_embd_batch {
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
|
||||
GGML_ASSERT(n_tokens > 0 && n_pos_per_embd > 0 && n_mmproj_embd > 0);
|
||||
pos .resize(n_tokens * n_pos_per_embd);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
@@ -157,6 +158,7 @@ struct decode_embd_batch {
|
||||
// M-RoPE for image
|
||||
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
|
||||
GGML_ASSERT(n_pos_per_embd == 4);
|
||||
GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
|
||||
seq_id_0[0] = seq_id;
|
||||
for (int y = 0; y < ny; y++) {
|
||||
for (int x = 0; x < nx; x++) {
|
||||
@@ -192,6 +194,7 @@ struct decode_embd_batch {
|
||||
}
|
||||
|
||||
llama_batch get_view(int offset, int n_tokens) {
|
||||
GGML_ASSERT(offset >= 0 && n_tokens > 0 && offset + n_tokens <= batch.n_tokens);
|
||||
llama_pos * pos_ptr;
|
||||
pos_view.clear();
|
||||
pos_view.reserve(n_tokens * n_pos_per_embd);
|
||||
@@ -235,6 +238,7 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past) {
|
||||
GGML_ASSERT(n_batch > 0);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
|
||||
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
@@ -312,6 +316,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||
int32_t n_batch,
|
||||
bool logits_last,
|
||||
llama_pos * new_n_past) {
|
||||
GGML_ASSERT(n_batch > 0);
|
||||
int32_t ret;
|
||||
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
@@ -508,6 +513,11 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
|
||||
fseek(f, 0, SEEK_END);
|
||||
long file_size = ftell(f);
|
||||
fseek(f, 0, SEEK_SET);
|
||||
if (file_size < 0) {
|
||||
LOG_ERR("Failed to get file size of %s\n", fname);
|
||||
fclose(f);
|
||||
return nullptr;
|
||||
}
|
||||
buf.resize(file_size);
|
||||
|
||||
size_t n_read = fread(buf.data(), 1, file_size, f);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,150 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "clip-model.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#define MTMD_INTERNAL_HEADER
|
||||
|
||||
// base class, models must inherit from this class
|
||||
struct mtmd_image_preprocessor {
|
||||
const clip_hparams & hparams;
|
||||
|
||||
mtmd_image_preprocessor(const clip_ctx * ctx): hparams(*clip_get_hparams(ctx)) {}
|
||||
|
||||
virtual ~mtmd_image_preprocessor() = default;
|
||||
virtual bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) = 0;
|
||||
|
||||
void img_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]);
|
||||
void img_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst);
|
||||
};
|
||||
|
||||
/**
|
||||
* implementation of LLaVA-UHD:
|
||||
* - https://arxiv.org/pdf/2403.11703
|
||||
* - https://github.com/thunlp/LLaVA-UHD
|
||||
* - https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
|
||||
*
|
||||
* overview:
|
||||
* - an image always has a single overview (downscaled image)
|
||||
* - an image can have 0 or multiple slices, depending on the image size
|
||||
* - each slice can then be considered as a separate image
|
||||
*
|
||||
* note: the terms "slice" and "tile" are used interchangeably
|
||||
*
|
||||
* for example:
|
||||
*
|
||||
* [overview] --> [slice 1] --> [slice 2]
|
||||
* | |
|
||||
* +--> [slice 3] --> [slice 4]
|
||||
*/
|
||||
struct mtmd_image_preprocessor_llava_uhd : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_llava_uhd(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
|
||||
struct slice_coordinates {
|
||||
int x;
|
||||
int y;
|
||||
clip_image_size size;
|
||||
};
|
||||
|
||||
struct slice_instructions {
|
||||
clip_image_size overview_size; // size of downscaled image
|
||||
clip_image_size refined_size; // size of image right before slicing (must be multiple of slice size)
|
||||
clip_image_size grid_size; // grid_size.width * grid_size.height = number of slices
|
||||
std::vector<slice_coordinates> slices;
|
||||
};
|
||||
|
||||
// LFM2 override this function to implement its custom slicing logic
|
||||
virtual slice_instructions get_slice_instructions(const clip_image_size & original_size);
|
||||
|
||||
std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 & img, const slice_instructions & inst, bool overview_first = true);
|
||||
|
||||
private:
|
||||
clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false);
|
||||
|
||||
clip_image_size resize_maintain_aspect_ratio(const clip_image_size & orig, const clip_image_size & target_max);
|
||||
|
||||
/**
|
||||
* Selects the best resolution from a list of possible resolutions based on the original size.
|
||||
*
|
||||
* For example, when given a list of resolutions:
|
||||
* - 100x100
|
||||
* - 200x100
|
||||
* - 100x200
|
||||
* - 200x200
|
||||
*
|
||||
* And an input image of size 111x200, then 100x200 is the best fit (least wasted resolution).
|
||||
*
|
||||
* @param original_size The original size of the image
|
||||
* @param possible_resolutions A list of possible resolutions
|
||||
* @return The best fit resolution
|
||||
*/
|
||||
clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector<clip_image_size> & possible_resolutions);
|
||||
int ensure_divide(int length, int patch_size);
|
||||
clip_image_size get_refine_size(const clip_image_size & original_size, const clip_image_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false);
|
||||
clip_image_size get_best_grid(const int max_slice_nums, const int multiple, const float log_ratio);
|
||||
};
|
||||
|
||||
// downscale or upscale the input image to fixed size
|
||||
struct mtmd_image_preprocessor_fixed_size : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_fixed_size(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
// resize image to multiple of patch_size*n_merge, while preserving aspect ratio
|
||||
// if image_resize_pad is true, the resized image will be padded, otherwise it will be either stretched or center-cropped depending on image_resize_pad
|
||||
// this is used by models with native support for dynamic image size, for example: Qwen-VL, Pixtral, Kimi-VL, etc
|
||||
struct mtmd_image_preprocessor_dyn_size : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_dyn_size(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
// similar to mtmd_image_preprocessor_dyn_size, but resize the image to have longest edge equal to hparams.image_longest_edge, while preserving aspect ratio
|
||||
struct mtmd_image_preprocessor_longest_edge : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_longest_edge(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
// custom llava-uhd slicing logic for LFM2
|
||||
// ref: https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py
|
||||
struct mtmd_image_preprocessor_lfm2 : mtmd_image_preprocessor_llava_uhd {
|
||||
// ref: https://huggingface.co/LiquidAI/LFM2.5-VL-1.6B/blob/main/processor_config.json
|
||||
static constexpr int min_tiles = 2;
|
||||
static constexpr int max_tiles = 10;
|
||||
static constexpr float max_pixels_tolerance = 2.0f;
|
||||
static constexpr int tile_size = 512;
|
||||
|
||||
using mtmd_image_preprocessor_llava_uhd::mtmd_image_preprocessor_llava_uhd;
|
||||
slice_instructions get_slice_instructions(const clip_image_size & original_size) override;
|
||||
|
||||
private:
|
||||
clip_image_size find_closest_aspect_ratio(
|
||||
float aspect_ratio,
|
||||
const std::vector<clip_image_size> & target_ratios,
|
||||
int width, int height);
|
||||
std::vector<clip_image_size> get_target_ratios();
|
||||
clip_image_size get_grid_layout(int height, int width);
|
||||
};
|
||||
|
||||
struct mtmd_image_preprocessor_idefics3 : mtmd_image_preprocessor_llava_uhd {
|
||||
mtmd_image_preprocessor_idefics3(const clip_ctx * ctx) : mtmd_image_preprocessor_llava_uhd(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
struct mtmd_image_preprocessor_internvl : mtmd_image_preprocessor_llava_uhd {
|
||||
mtmd_image_preprocessor_internvl(const clip_ctx * ctx) : mtmd_image_preprocessor_llava_uhd(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
struct mtmd_image_preprocessor_deepseekocr : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_deepseekocr(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
|
||||
struct mtmd_image_preprocessor_youtuvl : mtmd_image_preprocessor {
|
||||
mtmd_image_preprocessor_youtuvl(const clip_ctx * ctx) : mtmd_image_preprocessor(ctx) {}
|
||||
bool preprocess(const clip_image_u8 & img, clip_image_f32_batch & output) override;
|
||||
};
|
||||
+221
-134
@@ -2,6 +2,7 @@
|
||||
#include "clip-impl.h"
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-audio.h"
|
||||
#include "mtmd-image.h"
|
||||
#include "debug/mtmd-debug.h"
|
||||
|
||||
#include "llama.h"
|
||||
@@ -138,7 +139,7 @@ struct mtmd_context {
|
||||
|
||||
// for llava-uhd style models, we need special tokens in-between slices
|
||||
// minicpmv calls them "slices", llama 4 calls them "tiles"
|
||||
mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE;
|
||||
mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE;
|
||||
std::vector<llama_token> tok_ov_img_start; // overview image
|
||||
std::vector<llama_token> tok_ov_img_end; // overview image
|
||||
std::vector<llama_token> tok_slices_start; // start of all slices
|
||||
@@ -147,13 +148,14 @@ struct mtmd_context {
|
||||
std::vector<llama_token> tok_sli_img_end; // single slice end
|
||||
std::vector<llama_token> tok_sli_img_mid; // between 2 slices
|
||||
std::vector<llama_token> tok_row_end; // end of row
|
||||
bool tok_row_end_trail = false;
|
||||
bool ov_img_first = false;
|
||||
bool tok_row_end_trail = false;
|
||||
bool ov_img_first = false;
|
||||
|
||||
// string template for slice image delimiters with row/col (idefics3)
|
||||
std::string sli_img_start_tmpl;
|
||||
|
||||
std::unique_ptr<mtmd_audio_preprocessor> audio_preproc;
|
||||
std::unique_ptr<mtmd_image_preprocessor> image_preproc;
|
||||
|
||||
// TODO @ngxson : add timings
|
||||
|
||||
@@ -221,123 +223,193 @@ struct mtmd_context {
|
||||
|
||||
void init_vision() {
|
||||
GGML_ASSERT(ctx_v != nullptr);
|
||||
image_preproc.reset();
|
||||
|
||||
projector_type proj = clip_get_projector_type(ctx_v);
|
||||
int minicpmv_version = clip_is_minicpmv(ctx_v);
|
||||
if (minicpmv_version == 2) {
|
||||
// minicpmv 2.5 format:
|
||||
// <image> (overview) </image><slice><image> (slice) </image><image> (slice) </image>\n ... </slice>
|
||||
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5;
|
||||
tok_ov_img_start = {lookup_token("<image>")};
|
||||
tok_ov_img_end = {lookup_token("</image>")};
|
||||
tok_slices_start = {lookup_token("<slice>")};
|
||||
tok_slices_end = {lookup_token("</slice>")};
|
||||
tok_sli_img_start = tok_ov_img_start;
|
||||
tok_sli_img_end = tok_ov_img_end;
|
||||
tok_row_end = {lookup_token("\n")};
|
||||
tok_row_end_trail = false; // no trailing end-of-row token
|
||||
ov_img_first = true;
|
||||
|
||||
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6 || minicpmv_version == 100045) {
|
||||
// minicpmv 2.6 format:
|
||||
// <image> (overview) </image><slice> (slice) </slice><slice> (slice) </slice>\n ...
|
||||
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;
|
||||
tok_ov_img_start = {lookup_token("<image>")};
|
||||
tok_ov_img_end = {lookup_token("</image>")};
|
||||
tok_sli_img_start = {lookup_token("<slice>")};
|
||||
tok_sli_img_end = {lookup_token("</slice>")};
|
||||
tok_row_end = {lookup_token("\n")};
|
||||
tok_row_end_trail = false; // no trailing end-of-row token
|
||||
ov_img_first = true;
|
||||
switch (proj) {
|
||||
case PROJECTOR_TYPE_MLP:
|
||||
case PROJECTOR_TYPE_MLP_NORM:
|
||||
case PROJECTOR_TYPE_LDP:
|
||||
case PROJECTOR_TYPE_LDPV2:
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
case PROJECTOR_TYPE_JANUS_PRO:
|
||||
case PROJECTOR_TYPE_GLM_EDGE:
|
||||
{
|
||||
bool has_pinpoints = !clip_get_hparams(ctx_v)->image_res_candidates.empty();
|
||||
if (has_pinpoints) {
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
|
||||
} else {
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_fixed_size>(ctx_v);
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_MINICPMV:
|
||||
{
|
||||
int minicpmv_version = clip_is_minicpmv(ctx_v);
|
||||
if (minicpmv_version == 2) {
|
||||
// minicpmv 2.5 format:
|
||||
// <image> (overview) </image><slice><image> (slice) </image><image> (slice) </image>\n ... </slice>
|
||||
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5;
|
||||
tok_ov_img_start = {lookup_token("<image>")};
|
||||
tok_ov_img_end = {lookup_token("</image>")};
|
||||
tok_slices_start = {lookup_token("<slice>")};
|
||||
tok_slices_end = {lookup_token("</slice>")};
|
||||
tok_sli_img_start = tok_ov_img_start;
|
||||
tok_sli_img_end = tok_ov_img_end;
|
||||
tok_row_end = {lookup_token("\n")};
|
||||
tok_row_end_trail = false; // no trailing end-of-row token
|
||||
ov_img_first = true;
|
||||
|
||||
} else if (minicpmv_version != 0) {
|
||||
GGML_ASSERT(false && "unsupported minicpmv version");
|
||||
} else if (proj == PROJECTOR_TYPE_LLAMA4) {
|
||||
// llama 4 format:
|
||||
// <|image_start|>
|
||||
// (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
|
||||
// (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
|
||||
// ... <|tile_y_separator|> <-- trailing end-of-row token
|
||||
// <|image|> (overview) <-- overview image is last
|
||||
// <|image_end|>
|
||||
slice_tmpl = MTMD_SLICE_TMPL_LLAMA4;
|
||||
tok_ov_img_start = {lookup_token("<|image|>")};
|
||||
tok_sli_img_mid = {lookup_token("<|tile_x_separator|>")};
|
||||
tok_row_end = {lookup_token("<|tile_y_separator|>")};
|
||||
tok_row_end_trail = true; // add trailing end-of-row token
|
||||
ov_img_first = false; // overview image is last
|
||||
} else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6 || minicpmv_version == 100045) {
|
||||
// minicpmv 2.6 format:
|
||||
// <image> (overview) </image><slice> (slice) </slice><slice> (slice) </slice>\n ...
|
||||
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;
|
||||
tok_ov_img_start = {lookup_token("<image>")};
|
||||
tok_ov_img_end = {lookup_token("</image>")};
|
||||
tok_sli_img_start = {lookup_token("<slice>")};
|
||||
tok_sli_img_end = {lookup_token("</slice>")};
|
||||
tok_row_end = {lookup_token("\n")};
|
||||
tok_row_end_trail = false; // no trailing end-of-row token
|
||||
ov_img_first = true;
|
||||
|
||||
} else if (minicpmv_version != 0) {
|
||||
throw std::runtime_error(string_format("unsupported minicpmv version: %d\n", minicpmv_version));
|
||||
}
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
{
|
||||
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
|
||||
img_beg = "<|vision_start|>";
|
||||
img_end = "<|vision_end|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_YOUTUVL:
|
||||
{
|
||||
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
|
||||
img_beg = "<|vision_start|>";
|
||||
img_end = "<|vision_end|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_youtuvl>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_GEMMA3NV:
|
||||
{
|
||||
// <start_of_image> ... (image embeddings) ... <end_of_image>
|
||||
img_beg = "<start_of_image>";
|
||||
img_end = "<end_of_image>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_fixed_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
{
|
||||
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
|
||||
slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3;
|
||||
tok_ov_img_start = {lookup_token("\n\n"), lookup_token("<fake_token_around_image>"), lookup_token("<global-img>")};
|
||||
tok_ov_img_end = {lookup_token("<fake_token_around_image>")};
|
||||
tok_row_end = {lookup_token("\n")};
|
||||
sli_img_start_tmpl = "<fake_token_around_image><row_%d_col_%d>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_idefics3>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
{
|
||||
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
|
||||
img_end = "[IMG_END]";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PHI4:
|
||||
{
|
||||
// Phi-4 uses media marker insertion only. Keep image boundary text empty.
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
{
|
||||
// (more details in mtmd_context constructor)
|
||||
img_beg = "<|image_start|>";
|
||||
img_end = "<|image_end|>";
|
||||
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
|
||||
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_llava_uhd>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
{
|
||||
// <img> ... (image embeddings) ... </img>
|
||||
img_beg = "<img>";
|
||||
img_end = "</img>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_internvl>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
{
|
||||
// <|media_start|> ... (image embeddings) ... <|media_end|>
|
||||
img_beg = "<|media_start|>";
|
||||
img_end = "<|media_end|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
{
|
||||
// <|media_begin|> ... (image embeddings) ... <|media_end|>
|
||||
img_beg = "<|media_begin|>";
|
||||
img_end = "<|media_end|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
// <|im_start|> ... (image embeddings) ... <|im_end|>
|
||||
img_beg = "<|im_start|>";
|
||||
img_end = "<|im_end|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_longest_edge>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
{
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_fixed_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
{
|
||||
// multi-tile:
|
||||
// <|image_start|>
|
||||
// <|img_row_1_col_1|> (tile) <|img_row_1_col_2|> (tile) ...
|
||||
// <|img_thumbnail|> (thumbnail)
|
||||
// <|image_end|>
|
||||
// single-tile:
|
||||
// <|image_start|> (image) <|image_end|>
|
||||
img_beg = "<|image_start|>";
|
||||
img_end = "<|image_end|>";
|
||||
slice_tmpl = MTMD_SLICE_TMPL_LFM2;
|
||||
sli_img_start_tmpl = "<|img_row_%d_col_%d|>";
|
||||
tok_ov_img_start = {lookup_token("<|img_thumbnail|>")};
|
||||
ov_img_first = false;
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_lfm2>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLM4V:
|
||||
{
|
||||
// <|begin_of_image|> ... (image embeddings) ... <|end_of_image|>
|
||||
img_beg = "<|begin_of_image|>";
|
||||
img_end = "<|end_of_image|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PADDLEOCR:
|
||||
{
|
||||
// <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|>
|
||||
img_beg = "<|IMAGE_START|>";
|
||||
img_end = "<|IMAGE_END|>";
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_dyn_size>(ctx_v);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_DEEPSEEKOCR:
|
||||
{
|
||||
img_end = "\n"; // prevent empty batch on llama-server
|
||||
image_preproc = std::make_unique<mtmd_image_preprocessor_deepseekocr>(ctx_v);
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error(string_format("%s: unexpected vision projector type %d\n", __func__, proj));
|
||||
}
|
||||
|
||||
// set boi/eoi
|
||||
if (proj == PROJECTOR_TYPE_GEMMA3 || proj == PROJECTOR_TYPE_GEMMA3NV) {
|
||||
// <start_of_image> ... (image embeddings) ... <end_of_image>
|
||||
img_beg = "<start_of_image>";
|
||||
img_end = "<end_of_image>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
|
||||
slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3;
|
||||
tok_ov_img_start = {lookup_token("\n\n"), lookup_token("<fake_token_around_image>"), lookup_token("<global-img>")};
|
||||
tok_ov_img_end = {lookup_token("<fake_token_around_image>")};
|
||||
tok_row_end = {lookup_token("\n")};
|
||||
sli_img_start_tmpl = "<fake_token_around_image><row_%d_col_%d>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_PIXTRAL) {
|
||||
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
|
||||
img_end = "[IMG_END]";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) {
|
||||
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
|
||||
img_beg = "<|vision_start|>";
|
||||
img_end = "<|vision_end|>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_PHI4) {
|
||||
// Phi-4 uses media marker insertion only. Keep image boundary text empty.
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_LLAMA4) {
|
||||
// (more details in mtmd_context constructor)
|
||||
img_beg = "<|image_start|>";
|
||||
img_end = "<|image_end|>";
|
||||
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
|
||||
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_INTERNVL) {
|
||||
// <img> ... (image embeddings) ... </img>
|
||||
img_beg = "<img>";
|
||||
img_end = "</img>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_LIGHTONOCR) {
|
||||
// <|im_start|> ... (image embeddings) ... <|im_end|>
|
||||
img_beg = "<|im_start|>";
|
||||
img_end = "<|im_end|>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_LFM2) {
|
||||
// multi-tile:
|
||||
// <|image_start|>
|
||||
// <|img_row_1_col_1|> (tile) <|img_row_1_col_2|> (tile) ...
|
||||
// <|img_thumbnail|> (thumbnail)
|
||||
// <|image_end|>
|
||||
// single-tile:
|
||||
// <|image_start|> (image) <|image_end|>
|
||||
img_beg = "<|image_start|>";
|
||||
img_end = "<|image_end|>";
|
||||
slice_tmpl = MTMD_SLICE_TMPL_LFM2;
|
||||
sli_img_start_tmpl = "<|img_row_%d_col_%d|>";
|
||||
tok_ov_img_start = {lookup_token("<|img_thumbnail|>")};
|
||||
ov_img_first = false;
|
||||
} else if (proj == PROJECTOR_TYPE_GLM4V) {
|
||||
img_beg = "<|begin_of_image|>";
|
||||
img_end = "<|end_of_image|>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_PADDLEOCR) {
|
||||
// <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|>
|
||||
img_beg = "<|IMAGE_START|>";
|
||||
img_end = "<|IMAGE_END|>";
|
||||
}
|
||||
GGML_ASSERT(image_preproc != nullptr);
|
||||
}
|
||||
|
||||
void init_audio() {
|
||||
GGML_ASSERT(ctx_a != nullptr);
|
||||
audio_preproc.reset();
|
||||
|
||||
projector_type proj = clip_get_projector_type(ctx_a);
|
||||
|
||||
LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
|
||||
@@ -347,36 +419,40 @@ struct mtmd_context {
|
||||
switch (proj) {
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_QWEN25O:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
{
|
||||
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
|
||||
aud_beg = "<|audio_bos|>";
|
||||
aud_end = "<|audio_eos|>";
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
{
|
||||
// [BEGIN_AUDIO] ... (embeddings) ...
|
||||
aud_beg = "[BEGIN_AUDIO]";
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
|
||||
break;
|
||||
{
|
||||
// <sound> ... (embeddings) ...
|
||||
aud_beg = "<sound>";
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
{
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2A:
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
|
||||
break;
|
||||
{
|
||||
audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported audio projector type");
|
||||
throw std::runtime_error(string_format("%s: unexpected audio projector type %d\n", __func__, proj));
|
||||
}
|
||||
|
||||
// initialize audio preprocessor
|
||||
GGML_ASSERT(audio_preproc != nullptr);
|
||||
audio_preproc->initialize();
|
||||
|
||||
// set special tokens
|
||||
if (proj == PROJECTOR_TYPE_QWEN2A) {
|
||||
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
|
||||
aud_beg = "<|audio_bos|>";
|
||||
aud_end = "<|audio_eos|>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
// [BEGIN_AUDIO] ... (embeddings) ...
|
||||
aud_beg = "[BEGIN_AUDIO]";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_MUSIC_FLAMINGO) {
|
||||
// <sound> ... (embeddings) ...
|
||||
aud_beg = "<sound>";
|
||||
}
|
||||
}
|
||||
|
||||
// get clip ctx based on chunk type
|
||||
@@ -565,6 +641,11 @@ struct mtmd_tokenizer {
|
||||
add_text(ctx->img_beg, true); // add image begin token
|
||||
}
|
||||
|
||||
// sanity check
|
||||
GGML_ASSERT(bitmap->nx > 0 && bitmap->ny > 0);
|
||||
GGML_ASSERT(bitmap->data.size() == (size_t)bitmap->nx * bitmap->ny * 3);
|
||||
GGML_ASSERT(ctx->image_preproc != nullptr);
|
||||
|
||||
// convert mtmd_bitmap to clip_image_u8
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
img_u8->nx = bitmap->nx;
|
||||
@@ -574,7 +655,7 @@ struct mtmd_tokenizer {
|
||||
|
||||
// preprocess image
|
||||
clip_image_f32_batch batch_f32;
|
||||
bool ok = clip_image_preprocess(ctx->ctx_v, img_u8.get(), &batch_f32);
|
||||
bool ok = ctx->image_preproc->preprocess(*img_u8, batch_f32);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess image\n");
|
||||
return 2;
|
||||
@@ -696,6 +777,11 @@ struct mtmd_tokenizer {
|
||||
add_text(ctx->aud_beg, true); // add audio begin token
|
||||
}
|
||||
|
||||
// sanity check
|
||||
GGML_ASSERT(ctx->audio_preproc != nullptr);
|
||||
GGML_ASSERT(bitmap->data.size() > sizeof(float));
|
||||
GGML_ASSERT(bitmap->data.size() % sizeof(float) == 0);
|
||||
|
||||
// preprocess audio
|
||||
std::vector<mtmd_audio_mel> mel_spec_chunks;
|
||||
const float * samples = (const float *)bitmap->data.data();
|
||||
@@ -1225,7 +1311,8 @@ void mtmd_debug_preprocess_image(mtmd_context * ctx, const std::vector<uint8_t>
|
||||
img_u8.ny = ny;
|
||||
img_u8.buf = rgb_values;
|
||||
clip_image_f32_batch batch_f32;
|
||||
bool ok = clip_image_preprocess(ctx->ctx_v, &img_u8, &batch_f32);
|
||||
GGML_ASSERT(ctx->image_preproc != nullptr);
|
||||
bool ok = ctx->image_preproc->preprocess(img_u8, batch_f32);
|
||||
if (!ok) {
|
||||
LOG_ERR("%s: failed to preprocess image\n", __func__);
|
||||
return;
|
||||
|
||||
+23
-14
@@ -13,6 +13,8 @@ add_library(${TARGET} STATIC
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
server-tools.cpp
|
||||
server-tools.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
@@ -35,22 +37,29 @@ set(TARGET_SRCS
|
||||
server-models.cpp
|
||||
server-models.h
|
||||
)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html.gz
|
||||
loading.html
|
||||
)
|
||||
|
||||
foreach(asset ${PUBLIC_ASSETS})
|
||||
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
|
||||
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
|
||||
list(APPEND TARGET_SRCS ${output})
|
||||
add_custom_command(
|
||||
DEPENDS "${input}"
|
||||
OUTPUT "${output}"
|
||||
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
|
||||
option(LLAMA_BUILD_WEBUI "Build the embedded Web UI" ON)
|
||||
|
||||
if (LLAMA_BUILD_WEBUI)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html.gz
|
||||
loading.html
|
||||
)
|
||||
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
|
||||
endforeach()
|
||||
|
||||
foreach(asset ${PUBLIC_ASSETS})
|
||||
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
|
||||
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
|
||||
list(APPEND TARGET_SRCS ${output})
|
||||
add_custom_command(
|
||||
DEPENDS "${input}"
|
||||
OUTPUT "${output}"
|
||||
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
|
||||
)
|
||||
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
|
||||
endforeach()
|
||||
add_definitions(-DLLAMA_BUILD_WEBUI)
|
||||
else()
|
||||
endif()
|
||||
|
||||
add_executable(${TARGET} ${TARGET_SRCS})
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
||||
@@ -125,6 +125,61 @@ The framework automatically starts a `llama-server` instance, sends requests, an
|
||||
|
||||
For detailed instructions, see the [test documentation](./tests/README.md).
|
||||
|
||||
### API for tools
|
||||
|
||||
This endpoint is intended to be used internally by the Web UI and is subject to change or removal in the future.
|
||||
|
||||
**GET /tools**
|
||||
|
||||
Get a list of tools, each tool has these fields:
|
||||
- `tool` (string): the ID name of the tool, to be used in POST call. Example: `read_file`
|
||||
- `display_name` (string): the name to be displayed on UI. Example: `Read file`
|
||||
- `type` (string): always `"builtin"` for now
|
||||
- `permissions` (object): a mapping string --> boolean that indicates the permission required by this tool. This is useful for the UI to ask the user before calling the tool. For now, the only permission supported is `"write"`
|
||||
- `definition` (object): the OAI-compat definition of this tool
|
||||
|
||||
**POST /tools**
|
||||
|
||||
Invoke a tool call. The request body is a JSON object with:
|
||||
- `tool` (string): the name of the tool
|
||||
- `params` (object): a mapping from argument name (string) to argument value
|
||||
|
||||
Returns a JSON object. There are two response formats:
|
||||
|
||||
Format 1: Plain text. The text will be placed into a field called `plain_text_response`, example:
|
||||
|
||||
```json
|
||||
{
|
||||
"plain_text_response": "this is a text response"
|
||||
}
|
||||
```
|
||||
|
||||
The client should extract this value and place it inside the message content (note: the content is then plain text, not JSON). Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "this is a text response"
|
||||
}
|
||||
```
|
||||
|
||||
Format 2: Normal JSON response, example:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "cannot open this file"
|
||||
}
|
||||
```
|
||||
|
||||
This format requires `JSON.stringify` when the response is placed into the message content:
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "{\"error\":\"cannot open this file\"}"
|
||||
}
|
||||
```
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
|
||||
|
||||
+16
-1
@@ -36,7 +36,6 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--license` | show source code license and dependencies |
|
||||
| `-cl, --cache-list` | show list of models in cache |
|
||||
| `--completion-bash` | print source-able bash completion script for llama.cpp |
|
||||
| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
|
||||
| `-t, --threads N` | number of CPU threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
|
||||
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
|
||||
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
|
||||
@@ -189,11 +188,13 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--tags STRING` | set model tags, comma-separated (informational, not used for routing)<br/>(env: LLAMA_ARG_TAGS) |
|
||||
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
|
||||
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
|
||||
| `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)<br/>(env: LLAMA_ARG_REUSE_PORT) |
|
||||
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
|
||||
| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
|
||||
| `--webui-config JSON` | JSON that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG) |
|
||||
| `--webui-config-file PATH` | JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) |
|
||||
| `--webui-mcp-proxy, --no-webui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)<br/>(env: LLAMA_ARG_WEBUI_MCP_PROXY) |
|
||||
| `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)<br/>specify "all" to enable all tools<br/>available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff<br/>(env: LLAMA_ARG_TOOLS) |
|
||||
| `--webui, --no-webui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_WEBUI) |
|
||||
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
|
||||
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
|
||||
@@ -293,6 +294,12 @@ It is currently available in the following endpoints:
|
||||
|
||||
For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
|
||||
|
||||
### Built-in tools support
|
||||
|
||||
The server includes a set of built-in tools that enable the LLM to access the local file system directly from the Web UI.
|
||||
|
||||
To use this feature, start the server with `--tools all`. You can also enable only specific tools by passing a comma-separated list: `--tools name1,name2,...`. Run `--help` for the full list of available tool names.
|
||||
|
||||
## Build
|
||||
|
||||
`llama-server` is built alongside everything else from the root of the project
|
||||
@@ -1438,6 +1445,14 @@ curl http://localhost:8080/v1/messages/count_tokens \
|
||||
{"input_tokens": 10}
|
||||
```
|
||||
|
||||
## Server built-in tools
|
||||
|
||||
The server exposes a REST API under `/tools` that allows the Web UI to call built-in tools. This endpoint is intended to be used internally by the Web UI and is subject to change or removal in the future.
|
||||
|
||||
**Please do NOT use this endpoint in a downstream application**
|
||||
|
||||
For further documentation about this endpoint, please refer to [server internal documentation](./README-dev.md)
|
||||
|
||||
## Using multiple models
|
||||
|
||||
`llama-server` can be launched in a **router mode** that exposes an API for dynamically loading and unloading models. The main process (the "router") automatically forwards each request to the appropriate model instance.
|
||||
|
||||
Binary file not shown.
@@ -1110,7 +1110,7 @@ json oaicompat_chat_params_parse(
|
||||
reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
|
||||
}
|
||||
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
llama_params["reasoning_budget_tokens"] = reasoning_budget;
|
||||
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
|
||||
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
|
||||
|
||||
@@ -8,9 +8,11 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
// auto generated files (see README.md for details)
|
||||
#include "index.html.gz.hpp"
|
||||
#include "loading.html.hpp"
|
||||
#endif
|
||||
|
||||
//
|
||||
// HTTP implementation using cpp-httplib
|
||||
@@ -110,6 +112,22 @@ bool server_http_context::init(const common_params & params) {
|
||||
// set timeouts and change hostname and port
|
||||
srv->set_read_timeout (params.timeout_read);
|
||||
srv->set_write_timeout(params.timeout_write);
|
||||
srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) {
|
||||
int opt = 1;
|
||||
#ifdef _WIN32
|
||||
const char * optval = (const char *)&opt;
|
||||
#else
|
||||
const void * optval = &opt;
|
||||
#endif
|
||||
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, optval, sizeof(opt));
|
||||
if (reuse_port) {
|
||||
#ifdef SO_REUSEPORT
|
||||
setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, optval, sizeof(opt));
|
||||
#else
|
||||
LOG_WRN("%s: SO_REUSEPORT is not supported\n", __func__);
|
||||
#endif
|
||||
}
|
||||
});
|
||||
|
||||
if (params.api_keys.size() == 1) {
|
||||
auto key = params.api_keys[0];
|
||||
@@ -181,11 +199,14 @@ bool server_http_context::init(const common_params & params) {
|
||||
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
|
||||
bool ready = is_ready.load();
|
||||
if (!ready) {
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
auto tmp = string_split<std::string>(req.path, '.');
|
||||
if (req.path == "/" || tmp.back() == "html") {
|
||||
res.status = 503;
|
||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||
} else {
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
// no endpoints is allowed to be accessed when the server is not ready
|
||||
// this is to prevent any data races or inconsistent states
|
||||
res.status = 503;
|
||||
@@ -255,6 +276,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
// using embedded static index.html
|
||||
srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
|
||||
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
||||
@@ -268,6 +290,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
}
|
||||
return false;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
@@ -478,19 +478,17 @@ task_params server_task::params_from_json_cmpl(
|
||||
// Parse reasoning budget sampler parameters
|
||||
{
|
||||
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
|
||||
if (budget >= 0) {
|
||||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
}
|
||||
if (!end_tag.empty()) {
|
||||
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
||||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
}
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
}
|
||||
if (!end_tag.empty()) {
|
||||
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
||||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
|
||||
SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, params.sampling.generation_prompt.c_str(),
|
||||
|
||||
@@ -0,0 +1,768 @@
|
||||
#include "server-tools.h"
|
||||
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <climits>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
//
|
||||
// internal helpers
|
||||
//
|
||||
|
||||
// Build a NULL-terminated, argv-style array of C-string pointers for `v`.
// The pointers are borrowed from the strings in `v`; they remain usable only
// while `v` and its elements are alive and unmodified.
static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
    // one extra slot, pre-filled with nullptr, acts as the terminator
    std::vector<char *> ptrs(v.size() + 1, nullptr);
    for (size_t i = 0; i < v.size(); ++i) {
        ptrs[i] = const_cast<char *>(v[i].c_str());
    }
    return ptrs;
}
|
||||
|
||||
// Outcome of a child process launched through run_process().
struct run_proc_result {
    std::string output;            // captured output text (or a spawn-failure message)
    int         exit_code{-1};     // remains -1 until an exit code has been collected
    bool        timed_out{false};  // true when the timeout watchdog terminated the child
};
|
||||
|
||||
// Run `args` as a child process and capture its combined stdout/stderr.
//
//   args         - program name followed by its arguments (the program is searched in the user's PATH)
//   max_output   - maximum number of output bytes to keep; the excess is discarded and
//                  "\n[output truncated]" is appended to the returned text
//   timeout_secs - the child is terminated once this many seconds have elapsed
//
// Returns the captured output, the exit code collected by subprocess_join() and whether the
// timeout fired. If the process cannot be spawned, the output is "failed to spawn process"
// and exit_code keeps its default value of -1.
static run_proc_result run_process(
        const std::vector<std::string> & args,
        size_t max_output,
        int timeout_secs) {
    run_proc_result res;

    subprocess_s proc;
    auto argv = to_cstr_vec(args);

    int options = subprocess_option_no_window
                | subprocess_option_combined_stdout_stderr
                | subprocess_option_inherit_environment
                | subprocess_option_search_user_path;

    if (subprocess_create(argv.data(), options, &proc) != 0) {
        res.output = "failed to spawn process";
        return res;
    }

    std::atomic<bool> done{false};
    std::atomic<bool> timed_out{false};

    // watchdog: polls every 100 ms and terminates the child once the deadline has passed,
    // which closes the pipe and unblocks the read loop below
    std::thread timeout_thread([&]() {
        auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_secs);
        while (!done.load()) {
            if (std::chrono::steady_clock::now() >= deadline) {
                timed_out.store(true);
                subprocess_terminate(&proc);
                return;
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
    });

    FILE * f = subprocess_stdout(&proc);
    bool truncated = false;
    if (f) {
        char buf[4096];
        size_t n_read = 0;
        // fread() reports the exact number of bytes read, so output that contains NUL bytes
        // is preserved (a strlen() on an fgets() buffer would cut the line at the first NUL)
        while ((n_read = fread(buf, 1, sizeof(buf), f)) > 0) {
            if (truncated) {
                // keep reading until EOF even though nothing more is stored
                continue;
            }
            const size_t room = max_output - res.output.size();
            if (n_read <= room) {
                res.output.append(buf, n_read);
            } else {
                res.output.append(buf, room);
                truncated = true;
            }
        }
    }

    // reading is finished: stop the watchdog before the process handle is released
    done.store(true);
    if (timeout_thread.joinable()) {
        timeout_thread.join();
    }

    subprocess_join(&proc, &res.exit_code);
    subprocess_destroy(&proc);

    res.timed_out = timed_out.load();
    if (truncated) {
        res.output += "\n[output truncated]";
    }
    return res;
}
|
||||
|
||||
// Describe this tool as a JSON object with the fields "display_name", "tool", "type",
// "permissions" and "definition" (in that insertion order).
json server_tool::to_json() {
    json permissions = json::object();
    permissions["write"] = permission_write;

    json described = json::object();
    described["display_name"] = display_name;
    described["tool"]         = name;
    described["type"]         = "builtin";
    described["permissions"]  = permissions;
    described["definition"]   = get_definition();
    return described;
}
|
||||
|
||||
//
|
||||
// read_file: read a file with optional line range and line-number prefix
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_READ_FILE_MAX_SIZE = 16 * 1024; // 16 KB
|
||||
|
||||
struct server_tool_read_file : server_tool {
|
||||
server_tool_read_file() {
|
||||
name = "read_file";
|
||||
display_name = "Read file";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Read the contents of a file. Optionally specify a 1-based line range. "
|
||||
"If append_loc is true, each line is prefixed with its line number (e.g. \"1\u2192 ...\")."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path to the file"}}},
|
||||
{"start_line", {{"type", "integer"}, {"description", "First line to read, 1-based (default: 1)"}}},
|
||||
{"end_line", {{"type", "integer"}, {"description", "Last line to read, 1-based inclusive (default: end of file)"}}},
|
||||
{"append_loc", {{"type", "boolean"}, {"description", "Prefix each line with its line number"}}},
|
||||
}},
|
||||
{"required", json::array({"path"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
int start_line = json_value(params, "start_line", 1);
|
||||
int end_line = json_value(params, "end_line", -1); // -1 = no limit
|
||||
bool append_loc = json_value(params, "append_loc", false);
|
||||
|
||||
std::error_code ec;
|
||||
uintmax_t file_size = fs::file_size(path, ec);
|
||||
if (ec) {
|
||||
return {{"error", "cannot stat file: " + ec.message()}};
|
||||
}
|
||||
if (file_size > SERVER_TOOL_READ_FILE_MAX_SIZE && end_line == -1) {
|
||||
return {{"error", string_format(
|
||||
"file too large (%zu bytes, max %zu). Use start_line/end_line to read a portion.",
|
||||
(size_t)file_size, SERVER_TOOL_READ_FILE_MAX_SIZE)}};
|
||||
}
|
||||
|
||||
std::ifstream f(path);
|
||||
if (!f) {
|
||||
return {{"error", "failed to open file: " + path}};
|
||||
}
|
||||
|
||||
std::string result;
|
||||
std::string line;
|
||||
int lineno = 0;
|
||||
|
||||
while (std::getline(f, line)) {
|
||||
lineno++;
|
||||
if (lineno < start_line) continue;
|
||||
if (end_line != -1 && lineno > end_line) break;
|
||||
|
||||
std::string out_line;
|
||||
if (append_loc) {
|
||||
out_line = std::to_string(lineno) + "\u2192 " + line + "\n";
|
||||
} else {
|
||||
out_line = line + "\n";
|
||||
}
|
||||
|
||||
if (result.size() + out_line.size() > SERVER_TOOL_READ_FILE_MAX_SIZE) {
|
||||
result += "[output truncated]";
|
||||
break;
|
||||
}
|
||||
result += out_line;
|
||||
}
|
||||
|
||||
return {{"plain_text_response", result}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// file_glob_search: find files matching a glob pattern under a base directory
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_FILE_SEARCH_MAX_RESULTS = 100;
|
||||
|
||||
struct server_tool_file_glob_search : server_tool {
|
||||
server_tool_file_glob_search() {
|
||||
name = "file_glob_search";
|
||||
display_name = "File search";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Recursively search for files matching a glob pattern under a directory."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Base directory to search in"}}},
|
||||
{"include", {{"type", "string"}, {"description", "Glob pattern for files to include (e.g. \"**/*.cpp\"). Default: **"}}},
|
||||
{"exclude", {{"type", "string"}, {"description", "Glob pattern for files to exclude"}}},
|
||||
}},
|
||||
{"required", json::array({"path"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string base = params.at("path").get<std::string>();
|
||||
std::string include = json_value(params, "include", std::string("**"));
|
||||
std::string exclude = json_value(params, "exclude", std::string(""));
|
||||
|
||||
std::ostringstream output_text;
|
||||
size_t count = 0;
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : fs::recursive_directory_iterator(base,
|
||||
fs::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) continue;
|
||||
|
||||
std::string rel = fs::relative(entry.path(), base, ec).string();
|
||||
if (ec) continue;
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(include, rel)) continue;
|
||||
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
||||
|
||||
output_text << entry.path().string() << "\n";
|
||||
if (++count >= SERVER_TOOL_FILE_SEARCH_MAX_RESULTS) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
output_text << "\n---\nTotal matches: " << count << "\n";
|
||||
|
||||
return {{"plain_text_response", output_text.str()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// grep_search: search for a regex pattern in files
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_GREP_SEARCH_MAX_RESULTS = 100;
|
||||
|
||||
struct server_tool_grep_search : server_tool {
|
||||
server_tool_grep_search() {
|
||||
name = "grep_search";
|
||||
display_name = "Grep search";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Search for a regex pattern in files under a path. Returns matching lines."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "File or directory to search in"}}},
|
||||
{"pattern", {{"type", "string"}, {"description", "Regular expression pattern to search for"}}},
|
||||
{"include", {{"type", "string"}, {"description", "Glob pattern to filter files (default: **)"}}},
|
||||
{"exclude", {{"type", "string"}, {"description", "Glob pattern to exclude files"}}},
|
||||
{"return_line_numbers", {{"type", "boolean"}, {"description", "If true, include line numbers in results"}}},
|
||||
}},
|
||||
{"required", json::array({"path", "pattern"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
std::string pat_str = params.at("pattern").get<std::string>();
|
||||
std::string include = json_value(params, "include", std::string("**"));
|
||||
std::string exclude = json_value(params, "exclude", std::string(""));
|
||||
bool show_lineno = json_value(params, "return_line_numbers", false);
|
||||
|
||||
std::regex pattern;
|
||||
try {
|
||||
pattern = std::regex(pat_str);
|
||||
} catch (const std::regex_error & e) {
|
||||
return {{"error", std::string("invalid regex: ") + e.what()}};
|
||||
}
|
||||
|
||||
std::ostringstream output_text;
|
||||
size_t total = 0;
|
||||
|
||||
auto search_file = [&](const fs::path & fpath) {
|
||||
std::ifstream f(fpath);
|
||||
if (!f) return;
|
||||
std::string line;
|
||||
int lineno = 0;
|
||||
while (std::getline(f, line) && total < SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) {
|
||||
lineno++;
|
||||
if (std::regex_search(line, pattern)) {
|
||||
output_text << fpath.string() << ":";
|
||||
if (show_lineno) {
|
||||
output_text << lineno << ":";
|
||||
}
|
||||
output_text << line << "\n";
|
||||
total++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::error_code ec;
|
||||
if (fs::is_regular_file(path, ec)) {
|
||||
search_file(path);
|
||||
} else if (fs::is_directory(path, ec)) {
|
||||
for (const auto & entry : fs::recursive_directory_iterator(path,
|
||||
fs::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) continue;
|
||||
if (total >= SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) break;
|
||||
|
||||
std::string rel = fs::relative(entry.path(), path, ec).string();
|
||||
if (ec) continue;
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(include, rel)) continue;
|
||||
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
||||
|
||||
search_file(entry.path());
|
||||
}
|
||||
} else {
|
||||
return {{"error", "path does not exist: " + path}};
|
||||
}
|
||||
|
||||
output_text << "\n\n---\nTotal matches: " << total << "\n";
|
||||
|
||||
return {{"plain_text_response", output_text.str()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// exec_shell_command: run an arbitrary shell command
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE = 16 * 1024; // 16 KB
|
||||
static constexpr int SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT = 60; // seconds
|
||||
|
||||
struct server_tool_exec_shell_command : server_tool {
|
||||
server_tool_exec_shell_command() {
|
||||
name = "exec_shell_command";
|
||||
display_name = "Execute shell command";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Execute a shell command and return its output (stdout and stderr combined)."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"command", {{"type", "string"}, {"description", "Shell command to execute"}}},
|
||||
{"timeout", {{"type", "integer"}, {"description", string_format("Timeout in seconds (default 10, max %d)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT)}}},
|
||||
{"max_output_size", {{"type", "integer"}, {"description", string_format("Maximum output size in bytes (default %zu)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE)}}},
|
||||
}},
|
||||
{"required", json::array({"command"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string command = params.at("command").get<std::string>();
|
||||
int timeout = json_value(params, "timeout", 10);
|
||||
size_t max_output = (size_t) json_value(params, "max_output_size", (int) SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
||||
|
||||
timeout = std::min(timeout, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT);
|
||||
max_output = std::min(max_output, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
||||
|
||||
#ifdef _WIN32
|
||||
std::vector<std::string> args = {"cmd", "/c", command};
|
||||
#else
|
||||
std::vector<std::string> args = {"sh", "-c", command};
|
||||
#endif
|
||||
|
||||
auto res = run_process(args, max_output, timeout);
|
||||
|
||||
std::string text_output = res.output;
|
||||
text_output += string_format("\n[exit code: %d]", res.exit_code);
|
||||
if (res.timed_out) {
|
||||
text_output += " [exit due to timed out]";
|
||||
}
|
||||
|
||||
return {{"plain_text_response", text_output}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// write_file: create or overwrite a file
|
||||
//
|
||||
|
||||
struct server_tool_write_file : server_tool {
|
||||
server_tool_write_file() {
|
||||
name = "write_file";
|
||||
display_name = "Write file";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Write content to a file, creating it (including parent directories) if it does not exist. May use with edit_file for more complex edits."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path of the file to write"}}},
|
||||
{"content", {{"type", "string"}, {"description", "Content to write"}}},
|
||||
}},
|
||||
{"required", json::array({"path", "content"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
std::string content = params.at("content").get<std::string>();
|
||||
|
||||
std::error_code ec;
|
||||
fs::path fpath(path);
|
||||
if (fpath.has_parent_path()) {
|
||||
fs::create_directories(fpath.parent_path(), ec);
|
||||
if (ec) {
|
||||
return {{"error", "failed to create directories: " + ec.message()}};
|
||||
}
|
||||
}
|
||||
|
||||
std::ofstream f(path, std::ios::binary);
|
||||
if (!f) {
|
||||
return {{"error", "failed to open file for writing: " + path}};
|
||||
}
|
||||
f << content;
|
||||
if (!f) {
|
||||
return {{"error", "failed to write file: " + path}};
|
||||
}
|
||||
|
||||
return {{"result", "file written successfully"}, {"path", path}, {"bytes", content.size()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// edit_file: edit file content via line-based changes
|
||||
//
|
||||
|
||||
struct server_tool_edit_file : server_tool {
|
||||
server_tool_edit_file() {
|
||||
name = "edit_file";
|
||||
display_name = "Edit file";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description",
|
||||
"Edit a file by applying a list of line-based changes. "
|
||||
"Each change targets a 1-based inclusive line range and has a mode: "
|
||||
"\"replace\" (replace lines with content), "
|
||||
"\"delete\" (remove lines, content must be empty string), "
|
||||
"\"append\" (insert content after line_end). "
|
||||
"Set line_start to -1 to target the end of file (line_end is ignored in that case). "
|
||||
"Changes must not overlap. They are applied in reverse line order automatically."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path to the file to edit"}}},
|
||||
{"changes", {
|
||||
{"type", "array"},
|
||||
{"description", "List of changes to apply"},
|
||||
{"items", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"mode", {{"type", "string"}, {"description", "\"replace\", \"delete\", or \"append\""}}},
|
||||
{"line_start", {{"type", "integer"}, {"description", "First line of the range (1-based); use -1 for end of file"}}},
|
||||
{"line_end", {{"type", "integer"}, {"description", "Last line of the range (1-based, inclusive); ignored when line_start is -1"}}},
|
||||
{"content", {{"type", "string"}, {"description", "Content to insert; must be empty string for delete mode"}}},
|
||||
}},
|
||||
{"required", json::array({"mode", "line_start", "line_end", "content"})},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"path", "changes"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
const json & changes = params.at("changes");
|
||||
|
||||
if (!changes.is_array()) {
|
||||
return {{"error", "\"changes\" must be an array"}};
|
||||
}
|
||||
|
||||
// read file into lines
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
return {{"error", "failed to open file: " + path}};
|
||||
}
|
||||
std::vector<std::string> lines;
|
||||
{
|
||||
std::string line;
|
||||
while (std::getline(fin, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
}
|
||||
fin.close();
|
||||
|
||||
// validate and collect changes, then sort descending by line_start
|
||||
struct change_entry {
|
||||
std::string mode;
|
||||
int line_start; // 1-based
|
||||
int line_end; // 1-based inclusive
|
||||
std::string content;
|
||||
};
|
||||
std::vector<change_entry> entries;
|
||||
entries.reserve(changes.size());
|
||||
|
||||
for (const auto & ch : changes) {
|
||||
change_entry e;
|
||||
e.mode = ch.at("mode").get<std::string>();
|
||||
e.line_start = ch.at("line_start").get<int>();
|
||||
e.line_end = ch.at("line_end").get<int>();
|
||||
e.content = ch.at("content").get<std::string>();
|
||||
|
||||
if (e.mode != "replace" && e.mode != "delete" && e.mode != "append") {
|
||||
return {{"error", "invalid mode \"" + e.mode + "\"; must be replace, delete, or append"}};
|
||||
}
|
||||
if (e.mode == "delete" && !e.content.empty()) {
|
||||
return {{"error", "content must be empty string for delete mode"}};
|
||||
}
|
||||
int n = (int) lines.size();
|
||||
if (e.line_start == -1) {
|
||||
// -1 means end of file; line_end is ignored — normalize to point past last line
|
||||
e.line_start = n + 1;
|
||||
e.line_end = n + 1;
|
||||
} else {
|
||||
if (e.line_start < 1 || e.line_end < e.line_start) {
|
||||
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
|
||||
}
|
||||
if (e.line_end > n) {
|
||||
return {{"error", string_format("line_end %d exceeds file length %d", e.line_end, n)}};
|
||||
}
|
||||
}
|
||||
entries.push_back(std::move(e));
|
||||
}
|
||||
|
||||
// sort descending so earlier-indexed changes don't shift later ones
|
||||
std::sort(entries.begin(), entries.end(), [](const change_entry & a, const change_entry & b) {
|
||||
return a.line_start > b.line_start;
|
||||
});
|
||||
|
||||
// apply changes (0-based indices internally)
|
||||
for (const auto & e : entries) {
|
||||
int idx_start = e.line_start - 1; // 0-based
|
||||
int idx_end = e.line_end - 1; // 0-based inclusive
|
||||
|
||||
// split content into lines (preserve trailing newline awareness)
|
||||
std::vector<std::string> new_lines;
|
||||
if (!e.content.empty()) {
|
||||
std::istringstream ss(e.content);
|
||||
std::string ln;
|
||||
while (std::getline(ss, ln)) {
|
||||
new_lines.push_back(ln);
|
||||
}
|
||||
// if content ends with \n, getline consumed it — no extra empty line needed
|
||||
// if content does NOT end with \n, last line is still captured correctly
|
||||
}
|
||||
|
||||
if (e.mode == "replace") {
|
||||
// erase [idx_start, idx_end] and insert new_lines
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
lines.insert(lines.begin() + idx_start, new_lines.begin(), new_lines.end());
|
||||
} else if (e.mode == "delete") {
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
} else { // append
|
||||
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
|
||||
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
|
||||
}
|
||||
}
|
||||
|
||||
// write file back
|
||||
std::ofstream fout(path, std::ios::binary);
|
||||
if (!fout) {
|
||||
return {{"error", "failed to open file for writing: " + path}};
|
||||
}
|
||||
for (size_t i = 0; i < lines.size(); i++) {
|
||||
fout << lines[i];
|
||||
if (i + 1 < lines.size()) {
|
||||
fout << "\n";
|
||||
}
|
||||
}
|
||||
if (!lines.empty()) {
|
||||
fout << "\n";
|
||||
}
|
||||
if (!fout) {
|
||||
return {{"error", "failed to write file: " + path}};
|
||||
}
|
||||
|
||||
return {{"result", "file edited successfully"}, {"path", path}, {"lines", (int) lines.size()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// apply_diff: apply a unified diff via git apply
|
||||
//
|
||||
|
||||
struct server_tool_apply_diff : server_tool {
|
||||
server_tool_apply_diff() {
|
||||
name = "apply_diff";
|
||||
display_name = "Apply diff";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Apply a unified diff to edit one or more files using git apply. Use this instead of edit_file when the changes are complex."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"diff", {{"type", "string"}, {"description", "Unified diff content in git diff format"}}},
|
||||
}},
|
||||
{"required", json::array({"diff"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string diff = params.at("diff").get<std::string>();
|
||||
|
||||
// write diff to a temporary file
|
||||
static std::atomic<int> counter{0};
|
||||
std::string tmp_path = (fs::temp_directory_path() /
|
||||
("llama_patch_" + std::to_string(++counter) + ".patch")).string();
|
||||
|
||||
{
|
||||
std::ofstream f(tmp_path, std::ios::binary);
|
||||
if (!f) {
|
||||
return {{"error", "failed to create temp patch file"}};
|
||||
}
|
||||
f << diff;
|
||||
}
|
||||
|
||||
auto res = run_process({"git", "apply", tmp_path}, 4096, 10);
|
||||
|
||||
std::error_code ec;
|
||||
fs::remove(tmp_path, ec);
|
||||
|
||||
if (res.exit_code != 0) {
|
||||
return {{"error", "git apply failed (exit " + std::to_string(res.exit_code) + "): " + res.output}};
|
||||
}
|
||||
return {{"result", "patch applied successfully"}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// public API
|
||||
//
|
||||
|
||||
static std::vector<std::unique_ptr<server_tool>> build_tools() {
|
||||
std::vector<std::unique_ptr<server_tool>> tools;
|
||||
tools.push_back(std::make_unique<server_tool_read_file>());
|
||||
tools.push_back(std::make_unique<server_tool_file_glob_search>());
|
||||
tools.push_back(std::make_unique<server_tool_grep_search>());
|
||||
tools.push_back(std::make_unique<server_tool_exec_shell_command>());
|
||||
tools.push_back(std::make_unique<server_tool_write_file>());
|
||||
tools.push_back(std::make_unique<server_tool_edit_file>());
|
||||
tools.push_back(std::make_unique<server_tool_apply_diff>());
|
||||
return tools;
|
||||
}
|
||||
|
||||
void server_tools::setup(const std::vector<std::string> & enabled_tools) {
|
||||
if (!enabled_tools.empty()) {
|
||||
std::unordered_set<std::string> enabled_set(enabled_tools.begin(), enabled_tools.end());
|
||||
auto all_tools = build_tools();
|
||||
|
||||
tools.clear();
|
||||
for (auto & t : all_tools) {
|
||||
if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) {
|
||||
tools.push_back(std::move(t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handle_get = [this](const server_http_req &) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json result = json::array();
|
||||
for (const auto & t : tools) {
|
||||
result.push_back(t->to_json());
|
||||
}
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
handle_post = [this](const server_http_req & req) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
std::string tool_name = body.at("tool").get<std::string>();
|
||||
json params = body.value("params", json::object());
|
||||
json result = invoke(tool_name, params);
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const json::exception & e) {
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
// Dispatches `params` to the first enabled tool whose name equals `name`.
// Returns that tool's result, or {"error": "unknown tool: <name>"} when no
// enabled tool matches.
json server_tools::invoke(const std::string & name, const json & params) {
    const auto match = std::find_if(tools.begin(), tools.end(),
        [&name](const std::unique_ptr<server_tool> & candidate) {
            return candidate->name == name;
        });

    if (match == tools.end()) {
        return {{"error", "unknown tool: " + name}};
    }
    return (*match)->invoke(params);
}
|
||||
@@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
// Base interface implemented by every built-in server tool.
struct server_tool {
|
||||
// Identifier used to look the tool up in server_tools::invoke; the tool
// implementations also emit it as the function name in their definition.
std::string name;
|
||||
// Human-readable label (e.g. "Write file").
std::string display_name;
|
||||
// Defaults to false; the tools that run commands or modify files set it to true.
bool permission_write = false;
|
||||
|
||||
virtual ~server_tool() = default;
|
||||
// Returns the JSON function-calling definition of the tool
// (an object with "type": "function" and a "function" entry).
virtual json get_definition() = 0;
|
||||
// Runs the tool with the given parameters (taken by value) and returns its JSON result.
// NOTE(review): the implementations seen so far report failures as {"error": ...}
// and let json exceptions for missing parameters propagate - confirm for new tools.
virtual json invoke(json params) = 0;
|
||||
|
||||
// JSON description of the tool; used by the GET handler to list the enabled tools.
// The implementation is not part of this header.
json to_json();
|
||||
};
|
||||
|
||||
// Owns the enabled built-in tools and the HTTP handlers that expose them.
struct server_tools {
|
||||
// Enabled tools; filled by setup().
std::vector<std::unique_ptr<server_tool>> tools;
|
||||
|
||||
// Selects the tools named in enabled_tools ("all" selects every built-in tool)
// and assigns handle_get / handle_post.
void setup(const std::vector<std::string> & enabled_tools);
|
||||
// Invokes the enabled tool called `name` with `params`;
// returns {"error": ...} when no enabled tool has that name.
json invoke(const std::string & name, const json & params);
|
||||
|
||||
// Handler that lists the enabled tools; assigned by setup().
server_http_context::handler_t handle_get;
|
||||
// Handler that invokes a tool from a request body of the form
// {"tool": ..., "params": ...}; assigned by setup().
server_http_context::handler_t handle_post;
|
||||
};
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "server-http.h"
|
||||
#include "server-models.h"
|
||||
#include "server-cors-proxy.h"
|
||||
#include "server-tools.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
@@ -124,6 +125,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// register API routes
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
bool is_router_server = params.model.path.empty();
|
||||
std::optional<server_models_routes> models_routes{};
|
||||
@@ -211,6 +213,16 @@ int main(int argc, char ** argv) {
|
||||
ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get));
|
||||
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
|
||||
}
|
||||
// EXPERIMENTAL built-in tools
|
||||
if (!params.server_tools.empty()) {
|
||||
tools.setup(params.server_tools);
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n");
|
||||
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n");
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
/**
 * Svelte action: keeps an element transparent (and optionally shifted down by
 * `y` pixels) until it intersects the viewport, then transitions it to its
 * final state over `duration` milliseconds.
 *
 * Viewport entry is detected with an IntersectionObserver (threshold 0.05)
 * that disconnects itself after the first intersection.
 *
 * With `skipIfVisible`, an element whose bounding box already overlaps the
 * viewport when the action attaches (e.g. a markdown block promoted from
 * unstable during streaming) is left untouched, so it does not flash.
 */
export function fadeInView(
	node: HTMLElement,
	options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
) {
	const {
		duration: durationMs = 300,
		y: offsetY = 0,
		skipIfVisible: skipWhenVisible = false
	} = options;

	if (skipWhenVisible) {
		const { top, bottom, left, right } = node.getBoundingClientRect();
		const overlapsViewport =
			top < window.innerHeight && bottom > 0 && left < window.innerWidth && right > 0;

		// Already on screen: no hidden state, no observer.
		if (overlapsViewport) {
			return;
		}
	}

	// Hidden starting state plus the transition that animates away from it.
	const timing = `${durationMs}ms ease-out`;
	node.style.opacity = '0';
	node.style.transform = `translateY(${offsetY}px)`;
	node.style.transition = `opacity ${timing}, transform ${timing}`;

	$effect(() => {
		// Applied on the next animation frame after the element is first seen.
		const reveal = () => {
			node.style.opacity = '1';
			node.style.transform = 'translateY(0)';
		};

		const observer = new IntersectionObserver(
			(entries) => {
				entries.forEach((entry) => {
					if (!entry.isIntersecting) {
						return;
					}
					requestAnimationFrame(reveal);
					// One-shot: stop watching once the element has been seen.
					observer.disconnect();
				});
			},
			{ threshold: 0.05 }
		);

		observer.observe(node);

		// Effect teardown: release the observer.
		return () => {
			observer.disconnect();
		};
	});
}
|
||||
+1
-11
@@ -3,14 +3,12 @@
|
||||
ChatMessageAgenticContent,
|
||||
ChatMessageActions,
|
||||
ChatMessageStatistics,
|
||||
MarkdownContent,
|
||||
ModelBadge,
|
||||
ModelsSelector
|
||||
} from '$lib/components/app';
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { agenticStreamingToolCall } from '$lib/stores/agentic.svelte';
|
||||
import { autoResizeTextarea, copyToClipboard, isIMEComposing } from '$lib/utils';
|
||||
import { tick } from 'svelte';
|
||||
import { fade } from 'svelte/transition';
|
||||
@@ -87,13 +85,7 @@
|
||||
const hasAgenticMarkers = $derived(
|
||||
messageContent?.includes(AGENTIC_TAGS.TOOL_CALL_START) ?? false
|
||||
);
|
||||
const hasStreamingToolCall = $derived(
|
||||
isChatStreaming() && agenticStreamingToolCall(message.convId) !== null
|
||||
);
|
||||
const hasReasoningMarkers = $derived(messageContent?.includes(REASONING_TAGS.START) ?? false);
|
||||
const isStructuredContent = $derived(
|
||||
hasAgenticMarkers || hasReasoningMarkers || hasStreamingToolCall
|
||||
);
|
||||
const processingState = useProcessingState();
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
@@ -256,15 +248,13 @@
|
||||
{:else if message.role === MessageRole.ASSISTANT}
|
||||
{#if showRawOutput}
|
||||
<pre class="raw-output">{messageContent || ''}</pre>
|
||||
{:else if isStructuredContent}
|
||||
{:else}
|
||||
<ChatMessageAgenticContent
|
||||
content={messageContent || ''}
|
||||
isStreaming={isChatStreaming()}
|
||||
highlightTurns={highlightAgenticTurns}
|
||||
{message}
|
||||
/>
|
||||
{:else}
|
||||
<MarkdownContent content={messageContent || ''} attachments={message.extra} />
|
||||
{/if}
|
||||
{:else}
|
||||
<div class="text-sm whitespace-pre-wrap">
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { ChatMessage } from '$lib/components/app';
|
||||
import { setChatActionsContext } from '$lib/contexts';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
@@ -140,13 +141,18 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex h-full flex-col space-y-10 pt-24 {className}" style="height: auto; ">
|
||||
<div
|
||||
class="flex h-full flex-col space-y-10 pt-24 {className}"
|
||||
style="height: auto; min-height: calc(100dvh - 14rem);"
|
||||
>
|
||||
{#each displayMessages as { message, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<ChatMessage
|
||||
class="mx-auto w-full max-w-[48rem]"
|
||||
{message}
|
||||
{isLastAssistantMessage}
|
||||
{siblingInfo}
|
||||
/>
|
||||
<div use:fadeInView>
|
||||
<ChatMessage
|
||||
class="mx-auto w-full max-w-[48rem]"
|
||||
{message}
|
||||
{isLastAssistantMessage}
|
||||
{siblingInfo}
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
} from '$lib/components/app';
|
||||
import * as Alert from '$lib/components/ui/alert';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { INITIAL_SCROLL_DELAY } from '$lib/constants';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import {
|
||||
@@ -48,7 +47,7 @@
|
||||
let showFileErrorDialog = $state(false);
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
|
||||
const autoScroll = createAutoScrollController();
|
||||
const autoScroll = createAutoScrollController({ isColumnReverse: true });
|
||||
|
||||
let fileErrorData = $state<{
|
||||
generallyUnsupported: File[];
|
||||
@@ -310,13 +309,15 @@
|
||||
|
||||
afterNavigate(() => {
|
||||
if (!disableAutoScroll) {
|
||||
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
autoScroll.enable();
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
autoScroll.startObserving();
|
||||
|
||||
if (!disableAutoScroll) {
|
||||
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
@@ -333,10 +334,6 @@
|
||||
$effect(() => {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.updateInterval(isCurrentConversationLoading);
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if isDragOver}
|
||||
@@ -351,7 +348,7 @@
|
||||
<div
|
||||
bind:this={chatScrollContainer}
|
||||
aria-label="Chat interface with file drop zone"
|
||||
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
|
||||
class="flex h-full flex-col-reverse overflow-y-auto px-4 md:px-6"
|
||||
ondragenter={handleDragEnter}
|
||||
ondragleave={handleDragLeave}
|
||||
ondragover={handleDragOver}
|
||||
@@ -359,57 +356,59 @@
|
||||
onscroll={handleScroll}
|
||||
role="main"
|
||||
>
|
||||
<ChatMessages
|
||||
class="mb-16 md:mb-24"
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
autoScroll.scrollToBottom();
|
||||
}}
|
||||
/>
|
||||
<div class="flex flex-col">
|
||||
<ChatMessages
|
||||
class="mb-16 md:mb-24"
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
autoScroll.scrollToBottom();
|
||||
}}
|
||||
/>
|
||||
|
||||
<div
|
||||
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
|
||||
in:slide={{ duration: 150, axis: 'y' }}
|
||||
>
|
||||
<ChatScreenProcessingInfo />
|
||||
<div
|
||||
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
|
||||
in:slide={{ duration: 150, axis: 'y' }}
|
||||
>
|
||||
<ChatScreenProcessingInfo />
|
||||
|
||||
{#if hasPropsError}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
|
||||
in:fly={{ y: 10, duration: 250 }}
|
||||
>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
<Alert.Title class="flex items-center justify-between">
|
||||
<span>Server unavailable</span>
|
||||
<button
|
||||
onclick={() => serverStore.fetch()}
|
||||
disabled={isServerLoading}
|
||||
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
|
||||
{isServerLoading ? 'Retrying...' : 'Retry'}
|
||||
</button>
|
||||
</Alert.Title>
|
||||
<Alert.Description>{serverError()}</Alert.Description>
|
||||
</Alert.Root>
|
||||
{#if hasPropsError}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
|
||||
in:fly={{ y: 10, duration: 250 }}
|
||||
>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
<Alert.Title class="flex items-center justify-between">
|
||||
<span>Server unavailable</span>
|
||||
<button
|
||||
onclick={() => serverStore.fetch()}
|
||||
disabled={isServerLoading}
|
||||
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
|
||||
{isServerLoading ? 'Retrying...' : 'Retry'}
|
||||
</button>
|
||||
</Alert.Title>
|
||||
<Alert.Description>{serverError()}</Alert.Description>
|
||||
</Alert.Root>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={false}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={false}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -296,6 +296,11 @@
|
||||
label: 'Disable reasoning content parsing',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.EXCLUDE_REASONING_FROM_CONTEXT,
|
||||
label: 'Exclude reasoning from context',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.SHOW_RAW_OUTPUT_SWITCH,
|
||||
label: 'Enable raw output toggle',
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { DatabaseMessageExtra } from '$lib/types/database';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
|
||||
interface Props {
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
@@ -598,7 +599,7 @@
|
||||
: ''}"
|
||||
>
|
||||
{#each renderedBlocks as block (block.id)}
|
||||
<div class="markdown-block" data-block-id={block.id}>
|
||||
<div class="markdown-block" data-block-id={block.id} use:fadeInView={{ skipIfVisible: true }}>
|
||||
<!-- eslint-disable-next-line no-at-html-tags -->
|
||||
{@html block.html}
|
||||
</div>
|
||||
@@ -651,7 +652,6 @@
|
||||
/>
|
||||
|
||||
<style>
|
||||
.markdown-block,
|
||||
.markdown-block--unstable {
|
||||
display: contents;
|
||||
}
|
||||
|
||||
@@ -50,6 +50,8 @@ export const AGENTIC_REGEX = {
|
||||
PARTIAL_MARKER: /<<<[A-Za-z_]*$/,
|
||||
// Matches reasoning content blocks (including tags)
|
||||
REASONING_BLOCK: /<<<reasoning_content_start>>>[\s\S]*?<<<reasoning_content_end>>>/g,
|
||||
// Captures the reasoning text between start/end tags
|
||||
REASONING_EXTRACT: /<<<reasoning_content_start>>>([\s\S]*?)<<<reasoning_content_end>>>/,
|
||||
// Matches an opening reasoning tag and any remaining content (unterminated)
|
||||
REASONING_OPEN: /<<<reasoning_content_start>>>[\s\S]*$/,
|
||||
// Matches a complete agentic tool call display block (start to end marker)
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
export const AUTO_SCROLL_INTERVAL = 100;
|
||||
export const INITIAL_SCROLL_DELAY = 50;
|
||||
export const AUTO_SCROLL_AT_BOTTOM_THRESHOLD = 10;
|
||||
|
||||
@@ -10,6 +10,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean |
|
||||
theme: ColorMode.SYSTEM,
|
||||
showThoughtInProgress: false,
|
||||
disableReasoningParsing: false,
|
||||
excludeReasoningFromContext: false,
|
||||
showRawOutputSwitch: false,
|
||||
keepStatsVisible: false,
|
||||
showMessageStats: true,
|
||||
@@ -106,6 +107,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
showThoughtInProgress: 'Expand thought process by default when generating messages.',
|
||||
disableReasoningParsing:
|
||||
'Send reasoning_format=none to prevent server-side extraction of reasoning tokens into separate field',
|
||||
excludeReasoningFromContext:
|
||||
'Strip reasoning content from previous messages before sending to the model. When unchecked, reasoning is sent back via the reasoning_content field so the model can see its own chain-of-thought across turns.',
|
||||
showRawOutputSwitch:
|
||||
'Show toggle button to display messages as plain text instead of Markdown-formatted content',
|
||||
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',
|
||||
|
||||
@@ -54,6 +54,7 @@ export const SETTINGS_KEYS = {
|
||||
SHOW_TOOL_CALL_IN_PROGRESS: 'showToolCallInProgress',
|
||||
// Developer
|
||||
DISABLE_REASONING_PARSING: 'disableReasoningParsing',
|
||||
EXCLUDE_REASONING_FROM_CONTEXT: 'excludeReasoningFromContext',
|
||||
SHOW_RAW_OUTPUT_SWITCH: 'showRawOutputSwitch',
|
||||
CUSTOM: 'custom'
|
||||
} as const;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { AUTO_SCROLL_AT_BOTTOM_THRESHOLD, AUTO_SCROLL_INTERVAL } from '$lib/constants';
|
||||
|
||||
export interface AutoScrollOptions {
|
||||
/** Whether auto-scroll is disabled globally (e.g., from settings) */
|
||||
disabled?: boolean;
|
||||
isColumnReverse?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -12,6 +12,7 @@ export interface AutoScrollOptions {
|
||||
* - Auto-scrolls to bottom during streaming/loading
|
||||
* - Stops auto-scroll when user manually scrolls up
|
||||
* - Resumes auto-scroll when user scrolls back to bottom
|
||||
* - Supports both normal and column-reverse scroll containers
|
||||
*/
|
||||
export class AutoScrollController {
|
||||
private _autoScrollEnabled = $state(true);
|
||||
@@ -21,9 +22,14 @@ export class AutoScrollController {
|
||||
private _scrollTimeout: ReturnType<typeof setTimeout> | undefined;
|
||||
private _container: HTMLElement | undefined;
|
||||
private _disabled: boolean;
|
||||
private _isColumnReverse: boolean;
|
||||
private _mutationObserver: MutationObserver | null = null;
|
||||
private _rafPending = false;
|
||||
private _observerEnabled = false;
|
||||
|
||||
constructor(options: AutoScrollOptions = {}) {
|
||||
this._disabled = options.disabled ?? false;
|
||||
this._isColumnReverse = options.isColumnReverse ?? false;
|
||||
}
|
||||
|
||||
get autoScrollEnabled(): boolean {
|
||||
@@ -38,7 +44,12 @@ export class AutoScrollController {
|
||||
* Binds the controller to a scrollable container element.
|
||||
*/
|
||||
setContainer(container: HTMLElement | undefined): void {
	// Always detach from the previous element before swapping the reference,
	// so an observer is never left attached to a container we no longer track.
	this._doStopObserving();
	this._container = container;

	// Re-attach only if observing was requested, there is an element to watch,
	// and auto-scroll has not been disabled.
	const shouldObserve = this._observerEnabled && Boolean(container) && !this._disabled;
	if (!shouldObserve) return;

	this._doStartObserving();
}
|
||||
|
||||
/**
|
||||
@@ -49,6 +60,9 @@ export class AutoScrollController {
|
||||
if (disabled) {
|
||||
this._autoScrollEnabled = false;
|
||||
this.stopInterval();
|
||||
this._doStopObserving();
|
||||
} else if (this._observerEnabled && this._container && !this._mutationObserver) {
|
||||
this._doStartObserving();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,10 +73,23 @@ export class AutoScrollController {
|
||||
if (this._disabled || !this._container) return;
|
||||
|
||||
const { scrollTop, scrollHeight, clientHeight } = this._container;
|
||||
const distanceFromBottom = scrollHeight - scrollTop - clientHeight;
|
||||
|
||||
let distanceFromBottom: number;
|
||||
let isScrollingUp: boolean;
|
||||
|
||||
if (this._isColumnReverse) {
|
||||
// column-reverse: scrollTop=0 at bottom, negative when scrolled up
|
||||
distanceFromBottom = Math.abs(scrollTop);
|
||||
isScrollingUp = scrollTop < this._lastScrollTop;
|
||||
} else {
|
||||
// normal: scrollTop=0 at top, increases when scrolled down
|
||||
distanceFromBottom = scrollHeight - clientHeight - scrollTop;
|
||||
isScrollingUp = scrollTop < this._lastScrollTop;
|
||||
}
|
||||
|
||||
const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD;
|
||||
|
||||
if (scrollTop < this._lastScrollTop && !isAtBottom) {
|
||||
if (isScrollingUp && !isAtBottom) {
|
||||
this._userScrolledUp = true;
|
||||
this._autoScrollEnabled = false;
|
||||
} else if (isAtBottom && this._userScrolledUp) {
|
||||
@@ -90,10 +117,12 @@ export class AutoScrollController {
|
||||
scrollToBottom(behavior: ScrollBehavior = 'smooth'): void {
|
||||
if (this._disabled || !this._container) return;
|
||||
|
||||
this._container.scrollTo({
|
||||
top: this._container.scrollHeight,
|
||||
behavior
|
||||
});
|
||||
if (this._isColumnReverse) {
|
||||
// column-reverse: scrollTop=0 is the bottom
|
||||
this._container.scrollTo({ top: 0, behavior });
|
||||
} else {
|
||||
this._container.scrollTo({ top: this._container.scrollHeight, behavior });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -150,11 +179,69 @@ export class AutoScrollController {
|
||||
*/
|
||||
destroy(): void {
	// Stop both scroll drivers: the polling interval and the mutation observer.
	this.stopInterval();
	this._doStopObserving();

	// Cancel a scheduled scroll timeout, if one is outstanding.
	const pendingTimeout = this._scrollTimeout;
	if (!pendingTimeout) return;

	clearTimeout(pendingTimeout);
	this._scrollTimeout = undefined;
}
|
||||
|
||||
/**
 * Requests mutation-driven auto-scrolling: content changes inside the
 * container scroll it to the bottom, without waiting for an interval tick.
 *
 * The request is remembered in `_observerEnabled`, so the observer can be
 * attached later by `setContainer` if no container is bound yet.
 */
startObserving(): void {
	this._observerEnabled = true;

	// Nothing to attach to, auto-scroll is disabled, or an observer already exists.
	if (!this._container || this._disabled || this._mutationObserver) return;

	this._doStartObserving();
}
|
||||
|
||||
/**
 * Withdraws the observing request and disconnects the MutationObserver
 * if one is currently attached.
 */
stopObserving(): void {
	// Clear the request flag first so later container changes do not re-attach.
	this._observerEnabled = false;
	this._doStopObserving();
}
|
||||
|
||||
/**
 * Attaches a MutationObserver to the current container that keeps it pinned
 * to the bottom while auto-scroll is enabled.
 *
 * Does nothing when there is no container or an observer already exists.
 * Mutations are coalesced: while `_rafPending` is set no further animation
 * frame is requested, and the scroll itself happens inside that frame.
 */
private _doStartObserving(): void {
	if (!this._container || this._mutationObserver) return;

	// Snapshot the orientation at attach time.
	const isReverse = this._isColumnReverse;

	const pinToBottom = (): void => {
		this._rafPending = false;
		if (!this._autoScrollEnabled || !this._container) return;

		// column-reverse: scrollTop=0 is the bottom
		this._container.scrollTop = isReverse ? 0 : this._container.scrollHeight;
	};

	const observer = new MutationObserver(() => {
		// Skip when the user has scrolled away or a frame is already queued.
		if (!this._autoScrollEnabled || this._rafPending) return;

		this._rafPending = true;
		requestAnimationFrame(pinToBottom);
	});

	this._mutationObserver = observer;
	observer.observe(this._container, {
		childList: true,
		subtree: true,
		characterData: true
	});
}
|
||||
|
||||
/**
 * Disconnects and forgets the MutationObserver (if any) and clears the
 * frame-coalescing flag so a later observer starts from a clean state.
 */
private _doStopObserving(): void {
	this._mutationObserver?.disconnect();
	this._mutationObserver = null;
	this._rafPending = false;
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -57,6 +57,46 @@ export class ChatService {
|
||||
*
|
||||
*/
|
||||
|
||||
/**
 * Extracts reasoning text from content that contains internal reasoning tags.
 *
 * Accepts a plain string or an array of content parts; in an array only
 * non-empty text parts are scanned. Each match of
 * `AGENTIC_REGEX.REASONING_EXTRACT` contributes its first capture group, and
 * the captures are concatenated in order with no separator.
 *
 * @param content - Raw message content, possibly containing reasoning tags.
 * @returns The concatenated reasoning text, or `undefined` when the content
 *   is empty, is neither a string nor an array, or yields no reasoning text.
 */
private static extractReasoningFromContent(
	content: ApiChatMessageData['content'] | null | undefined
): string | undefined {
	if (!content) return undefined;

	// Work on a private global copy of the shared pattern. Carrying the flags
	// over keeps modifiers such as `s`, `i` or `u` that a `.source`-only copy
	// would silently drop; `g`/`y` are normalised to a single `g`, which
	// `matchAll` requires. `matchAll` iterates over its own clone, so neither
	// this copy nor the shared regex has its `lastIndex` advanced.
	const { source, flags } = AGENTIC_REGEX.REASONING_EXTRACT;
	const pattern = new RegExp(source, `${flags.replace(/[gy]/g, '')}g`);

	// `matchAll` also steps past empty matches, so scanning always terminates.
	const extractFromString = (text: string): string =>
		Array.from(text.matchAll(pattern), (match) => match[1] ?? '').join('');

	if (typeof content === 'string') {
		return extractFromString(content) || undefined;
	}

	if (!Array.isArray(content)) return undefined;

	const collected: string[] = [];
	for (const part of content) {
		if (part.type === ContentPartType.TEXT && part.text) {
			const reasoning = extractFromString(part.text);
			if (reasoning) collected.push(reasoning);
		}
	}

	return collected.length > 0 ? collected.join('') : undefined;
}
|
||||
|
||||
/**
|
||||
* Sends a chat completion request to the llama.cpp server.
|
||||
* Supports both streaming and non-streaming responses with comprehensive parameter configuration.
|
||||
@@ -111,7 +151,8 @@ export class ChatService {
|
||||
custom,
|
||||
timings_per_token,
|
||||
// Config options
|
||||
disableReasoningParsing
|
||||
disableReasoningParsing,
|
||||
excludeReasoningFromContext
|
||||
} = options;
|
||||
|
||||
const normalizedMessages: ApiChatMessageData[] = messages
|
||||
@@ -159,14 +200,24 @@ export class ChatService {
|
||||
}
|
||||
|
||||
const requestBody: ApiChatCompletionRequest = {
|
||||
messages: normalizedMessages.map((msg: ApiChatMessageData) => ({
|
||||
role: msg.role,
|
||||
// Strip reasoning tags/content from the prompt to avoid polluting KV cache.
|
||||
// TODO: investigate backend expectations for reasoning tags and add a toggle if needed.
|
||||
content: ChatService.stripReasoningContent(msg.content),
|
||||
tool_calls: msg.tool_calls,
|
||||
tool_call_id: msg.tool_call_id
|
||||
})),
|
||||
messages: normalizedMessages.map((msg: ApiChatMessageData) => {
|
||||
// Always strip internal reasoning/agentic tags from content
|
||||
const cleanedContent = ChatService.stripReasoningContent(msg.content);
|
||||
const mapped: ApiChatCompletionRequest['messages'][0] = {
|
||||
role: msg.role,
|
||||
content: cleanedContent,
|
||||
tool_calls: msg.tool_calls,
|
||||
tool_call_id: msg.tool_call_id
|
||||
};
|
||||
// When preserving reasoning, extract it from raw content and send as separate field
|
||||
if (!excludeReasoningFromContext) {
|
||||
const reasoning = ChatService.extractReasoningFromContent(msg.content);
|
||||
if (reasoning) {
|
||||
mapped.reasoning_content = reasoning;
|
||||
}
|
||||
}
|
||||
return mapped;
|
||||
}),
|
||||
stream,
|
||||
return_progress: stream ? true : undefined,
|
||||
tools: tools && tools.length > 0 ? tools : undefined
|
||||
|
||||
@@ -227,6 +227,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
||||
serverKey: 'alwaysShowAgenticTurns',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'excludeReasoningFromContext',
|
||||
serverKey: 'excludeReasoningFromContext',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
}
|
||||
];
|
||||
|
||||
|
||||
@@ -1479,6 +1479,8 @@ class ChatStore {
|
||||
|
||||
if (currentConfig.disableReasoningParsing) apiOptions.disableReasoningParsing = true;
|
||||
|
||||
if (currentConfig.excludeReasoningFromContext) apiOptions.excludeReasoningFromContext = true;
|
||||
|
||||
if (hasValue(currentConfig.temperature))
|
||||
apiOptions.temperature = Number(currentConfig.temperature);
|
||||
|
||||
|
||||
+4
@@ -45,6 +45,7 @@ export interface ApiErrorResponse {
|
||||
export interface ApiChatMessageData {
|
||||
role: ChatRole;
|
||||
content: string | ApiChatMessageContentPart[];
|
||||
reasoning_content?: string;
|
||||
tool_calls?: ApiChatCompletionToolCall[];
|
||||
tool_call_id?: string;
|
||||
timestamp?: number;
|
||||
@@ -201,6 +202,9 @@ export interface ApiChatCompletionRequest {
|
||||
messages: Array<{
|
||||
role: ChatRole;
|
||||
content: string | ApiChatMessageContentPart[];
|
||||
reasoning_content?: string;
|
||||
tool_calls?: ApiChatCompletionToolCall[];
|
||||
tool_call_id?: string;
|
||||
}>;
|
||||
stream?: boolean;
|
||||
model?: string;
|
||||
|
||||
@@ -24,6 +24,8 @@ export interface SettingsChatServiceOptions {
|
||||
systemMessage?: string;
|
||||
// Disable reasoning parsing (use 'none' instead of 'auto')
|
||||
disableReasoningParsing?: boolean;
|
||||
// Strip reasoning content from context before sending
|
||||
excludeReasoningFromContext?: boolean;
|
||||
tools?: OpenAIToolDefinition[];
|
||||
// Generation parameters
|
||||
temperature?: number;
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { AGENTIC_REGEX, REASONING_TAGS } from '$lib/constants/agentic';
|
||||
import { ContentPartType } from '$lib/enums';
|
||||
|
||||
// Replicate ChatService.extractReasoningFromContent (private static)
//
// Returns the concatenated first capture group of every
// AGENTIC_REGEX.REASONING_EXTRACT match found in a string or in the non-empty
// text parts of a content-part array; undefined when nothing is found.
function extractReasoningFromContent(
	content: string | Array<{ type: string; text?: string }> | null | undefined
): string | undefined {
	if (!content) return undefined;

	// Private global copy of the shared pattern: flags other than g/y are
	// carried over (a `.source`-only copy would drop them) and `g` is added
	// because `matchAll` requires it. `matchAll` iterates over its own clone,
	// so no `lastIndex` state is shared, and it steps past empty matches.
	const { source, flags } = AGENTIC_REGEX.REASONING_EXTRACT;
	const pattern = new RegExp(source, `${flags.replace(/[gy]/g, '')}g`);

	const extractFromString = (text: string): string =>
		Array.from(text.matchAll(pattern), (match) => match[1] ?? '').join('');

	if (typeof content === 'string') {
		return extractFromString(content) || undefined;
	}

	if (!Array.isArray(content)) return undefined;

	const collected: string[] = [];
	for (const part of content) {
		if (part.type === ContentPartType.TEXT && part.text) {
			const reasoning = extractFromString(part.text);
			if (reasoning) collected.push(reasoning);
		}
	}

	return collected.length > 0 ? collected.join('') : undefined;
}
|
||||
|
||||
// Replicate ChatService.stripReasoningContent (private static)
//
// Removes reasoning and agentic tool-call markup from a string, or from each
// non-empty text part of a content-part array. Other values pass through as-is.
function stripReasoningContent(
	content: string | Array<{ type: string; text?: string }> | null | undefined
): typeof content {
	// Applies the four patterns in a fixed order: complete reasoning blocks,
	// the open reasoning pattern, complete tool-call blocks, the open tool-call pattern.
	const scrub = (text: string): string =>
		[
			AGENTIC_REGEX.REASONING_BLOCK,
			AGENTIC_REGEX.REASONING_OPEN,
			AGENTIC_REGEX.AGENTIC_TOOL_CALL_BLOCK,
			AGENTIC_REGEX.AGENTIC_TOOL_CALL_OPEN
		].reduce((cleaned, pattern) => cleaned.replace(pattern, ''), text);

	if (!content) return content;
	if (typeof content === 'string') return scrub(content);
	if (!Array.isArray(content)) return content;

	return content.map((part) => {
		const isNonEmptyText = part.type === ContentPartType.TEXT && Boolean(part.text);
		if (!isNonEmptyText) return part;

		return { ...part, text: scrub(part.text as string) };
	});
}
|
||||
|
||||
// Simulate the message mapping logic from ChatService.sendMessage:
// content is always stripped of internal tags; reasoning is attached as a
// separate field only when it is not excluded and some was found.
function buildApiMessage(
	content: string,
	excludeReasoningFromContext: boolean
): { role: string; content: string; reasoning_content?: string } {
	const message: { role: string; content: string; reasoning_content?: string } = {
		role: 'assistant',
		content: stripReasoningContent(content) as string
	};

	if (excludeReasoningFromContext) return message;

	const reasoning = extractReasoningFromContent(content);
	if (reasoning) {
		message.reasoning_content = reasoning;
	}

	return message;
}
|
||||
|
||||
// Helper: builds stored message content the way the chat store does while
// streaming — the reasoning enclosed in the internal tags, then the visible text.
function wrapReasoning(reasoning: string, content: string): string {
	const taggedReasoning = `${REASONING_TAGS.START}${reasoning}${REASONING_TAGS.END}`;
	return `${taggedReasoning}${content}`;
}
|
||||
|
||||
describe('reasoning content extraction', () => {
	it('extracts reasoning from tagged string content', () => {
		const tagged = wrapReasoning('step 1, step 2', 'The answer is 42.');

		expect(extractReasoningFromContent(tagged)).toBe('step 1, step 2');
	});

	it('returns undefined when no reasoning tags present', () => {
		const untagged = 'Just a normal response.';

		expect(extractReasoningFromContent(untagged)).toBeUndefined();
	});

	it('returns undefined for null/empty input', () => {
		for (const empty of [null, undefined, '']) {
			expect(extractReasoningFromContent(empty)).toBeUndefined();
		}
	});

	it('extracts reasoning from content part arrays', () => {
		const textPart = {
			type: ContentPartType.TEXT,
			text: wrapReasoning('thinking hard', 'result')
		};

		expect(extractReasoningFromContent([textPart])).toBe('thinking hard');
	});

	it('handles multiple reasoning blocks', () => {
		// Two tagged blocks separated by plain text; captures are joined with no separator.
		const { START, END } = REASONING_TAGS;
		const input = `${START}block1${END}middle${START}block2${END}end`;

		expect(extractReasoningFromContent(input)).toBe('block1block2');
	});

	it('ignores non-text content parts', () => {
		const imagePart = { type: 'image_url', text: wrapReasoning('hidden', 'img') };

		expect(extractReasoningFromContent([imagePart])).toBeUndefined();
	});
});
|
||||
|
||||
describe('strip reasoning content', () => {
	it('removes reasoning tags from string content', () => {
		const stripped = stripReasoningContent(wrapReasoning('internal thoughts', 'visible answer'));

		expect(stripped).toBe('visible answer');
	});

	it('removes reasoning from content part arrays', () => {
		const textPart = {
			type: ContentPartType.TEXT,
			text: wrapReasoning('thoughts', 'answer')
		};

		const [firstPart] = stripReasoningContent([textPart]) as Array<{
			type: string;
			text?: string;
		}>;

		expect(firstPart.text).toBe('answer');
	});
});
|
||||
|
||||
describe('API message building with reasoning preservation', () => {
	// Stored form of an assistant turn: tagged reasoning followed by the visible answer.
	const reasoningText = 'Let me think: 2+2=4, basic arithmetic.';
	const answerText = 'The answer is 4.';
	const storedContent = wrapReasoning(reasoningText, answerText);

	it('preserves reasoning_content when excludeReasoningFromContext is false', () => {
		const { content, reasoning_content } = buildApiMessage(storedContent, false);

		expect(content).toBe('The answer is 4.');
		expect(reasoning_content).toBe('Let me think: 2+2=4, basic arithmetic.');
		// no internal tags leak into either field
		expect(content).not.toContain('<<<');
		expect(reasoning_content).not.toContain('<<<');
	});

	it('strips reasoning_content when excludeReasoningFromContext is true', () => {
		const { content, reasoning_content } = buildApiMessage(storedContent, true);

		expect(content).toBe('The answer is 4.');
		expect(reasoning_content).toBeUndefined();
	});

	it('handles content with no reasoning in both modes', () => {
		const plain = 'No reasoning here.';

		const preserved = buildApiMessage(plain, false);
		expect(preserved.content).toBe(plain);
		expect(preserved.reasoning_content).toBeUndefined();

		const excluded = buildApiMessage(plain, true);
		expect(excluded.content).toBe(plain);
		expect(excluded.reasoning_content).toBeUndefined();
	});

	it('cleans agentic tool call blocks from content even when preserving reasoning', () => {
		// A stored agentic tool-call block appended after the tagged reasoning and text.
		const toolCallBlock = [
			'\n\n<<<AGENTIC_TOOL_CALL_START>>>\n',
			'<<<TOOL_NAME:bash>>>\n',
			'<<<TOOL_ARGS_START>>>\n{}\n<<<TOOL_ARGS_END>>>\nout\n',
			'<<<AGENTIC_TOOL_CALL_END>>>\n'
		].join('');
		const input = wrapReasoning('plan', 'text') + toolCallBlock;

		const { content, reasoning_content } = buildApiMessage(input, false);

		expect(content).not.toContain('<<<');
		expect(reasoning_content).toBe('plan');
	});
});
|
||||
Reference in New Issue
Block a user